From 7075f8d0e28d3f62bc5ea9bc15489bf3ab25c987 Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Thu, 17 Aug 2023 21:11:09 +0200 Subject: [PATCH 01/15] ReBRAC finetune --- algorithms/finetune/rebrac.py | 1028 +++++++++++++++++ .../rebrac/antmaze/large_diverse_v2.yaml | 35 + .../rebrac/antmaze/large_play_v2.yaml | 35 + .../rebrac/antmaze/medium_diverse_v2.yaml | 35 + .../rebrac/antmaze/medium_play_v2.yaml | 35 + .../rebrac/antmaze/umaze_diverse_v2.yaml | 35 + configs/finetune/rebrac/antmaze/umaze_v2.yaml | 36 + configs/finetune/rebrac/door/cloned_v1.yaml | 35 + configs/finetune/rebrac/hammer/cloned_v1.yaml | 35 + configs/finetune/rebrac/pen/cloned_v1.yaml | 35 + .../finetune/rebrac/relocate/cloned_v1.yaml | 35 + 11 files changed, 1379 insertions(+) create mode 100644 algorithms/finetune/rebrac.py create mode 100644 configs/finetune/rebrac/antmaze/large_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/large_play_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/medium_play_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml create mode 100644 configs/finetune/rebrac/antmaze/umaze_v2.yaml create mode 100644 configs/finetune/rebrac/door/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/hammer/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/pen/cloned_v1.yaml create mode 100644 configs/finetune/rebrac/relocate/cloned_v1.yaml diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py new file mode 100644 index 00000000..d29822e6 --- /dev/null +++ b/algorithms/finetune/rebrac.py @@ -0,0 +1,1028 @@ +import os + +os.environ["TF_CUDNN_DETERMINISTIC"] = "1" + +import math +import wandb +import uuid +from copy import deepcopy +import pyrallis +import random + +import chex +# import d4rl # noqa +# import gym +import jax +import numpy as np +import optax +import tqdm +from typing import Sequence, Union + +from functools import partial +from dataclasses import dataclass, asdict +from flax.core import FrozenDict +from typing import Dict, Tuple, Any, Callable +import flax.linen as nn +import jax.numpy as jnp +from tqdm.auto import trange + +from flax.training.train_state import TrainState + +ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") + +default_kernel_init = nn.initializers.lecun_normal() +default_bias_init = nn.initializers.zeros + + +@dataclass +class Config: + # wandb params + project: str = "ReBRAC" + group: str = "rebrac-finetune" + name: str = "rebrac-finetune" + # model params + actor_learning_rate: float = 3e-4 + critic_learning_rate: float = 3e-4 + hidden_dim: int = 256 + actor_n_hiddens: int = 3 + critic_n_hiddens: int = 3 + replay_buffer_size: int = 2_000_000 + mixing_ratio: float = 0.5 + gamma: float = 0.99 + tau: float = 5e-3 + actor_bc_coef: float = 1.0 + critic_bc_coef: float = 1.0 + actor_ln: bool = False + critic_ln: bool = True + policy_noise: float = 0.2 + noise_clip: float = 0.5 + expl_noise: float = 0.0 + policy_freq: int = 2 + normalize_q: bool = True + min_decay_coef: float = 0.5 + use_calibration: bool = False + reset_opts: bool = False + # training params + dataset_name: str = "halfcheetah-medium-v2" + batch_size: int = 256 + num_offline_updates: int = 1_000_000 + num_online_updates: int = 1_000_000 + num_warmup_steps: int = 0 + normalize_reward: bool = False + normalize_states: bool = False + # evaluation params + eval_episodes: int = 10 + eval_every: int = 5000 + # general params + train_seed: 
int = 10 + eval_seed: int = 42 + + def __post_init__(self): + self.name = f"{self.name}-{self.dataset_name}-{str(uuid.uuid4())[:8]}" + + +# def pytorch_init(fan_in: float) -> Callable: +# """ +# Default init for PyTorch Linear layer weights and biases: +# https://pytorch.org/docs/stable/generated/torch.nn.Linear.html +# """ +# bound = math.sqrt(1 / fan_in) +# +# def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: +# return jax.random.uniform( +# key, shape=shape, minval=-bound, maxval=bound, dtype=dtype +# ) +# +# return _init +# +# +# def uniform_init(bound: float) -> Callable: +# def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: +# return jax.random.uniform( +# key, shape=shape, minval=-bound, maxval=bound, dtype=dtype +# ) +# +# return _init +# +# +# def identity(x: Any) -> Any: +# return x +# +# +# class DetActor(nn.Module): +# action_dim: int +# hidden_dim: int = 256 +# layernorm: bool = True +# n_hiddens: int = 3 +# +# @nn.compact +# def __call__(self, state: jax.Array) -> jax.Array: +# s_d, h_d = state.shape[-1], self.hidden_dim +# # Initialization as in the EDAC paper +# layers = [ +# nn.Dense( +# self.hidden_dim, +# kernel_init=pytorch_init(s_d), +# bias_init=nn.initializers.constant(0.1), +# ), +# nn.relu, +# nn.LayerNorm() if self.layernorm else identity, +# ] +# for _ in range(self.n_hiddens - 1): +# layers += [ +# nn.Dense( +# self.hidden_dim, +# kernel_init=pytorch_init(h_d), +# bias_init=nn.initializers.constant(0.1), +# ), +# nn.relu, +# nn.LayerNorm() if self.layernorm else identity, +# ] +# layers += [ +# nn.Dense( +# self.action_dim, +# kernel_init=uniform_init(1e-3), +# bias_init=uniform_init(1e-3), +# ), +# nn.tanh, +# ] +# net = nn.Sequential(layers) +# actions = net(state) +# return actions +# +# +# class Critic(nn.Module): +# hidden_dim: int = 256 +# layernorm: bool = True +# n_hiddens: int = 3 +# +# @nn.compact +# def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: +# s_d, a_d, h_d = state.shape[-1], action.shape[-1], self.hidden_dim +# # Initialization as in the EDAC paper +# layers = [ +# nn.Dense( +# self.hidden_dim, +# kernel_init=pytorch_init(s_d + a_d), +# bias_init=nn.initializers.constant(0.1), +# ), +# nn.relu, +# nn.LayerNorm() if self.layernorm else identity, +# ] +# for _ in range(self.n_hiddens - 1): +# layers += [ +# nn.Dense( +# self.hidden_dim, +# kernel_init=pytorch_init(h_d), +# bias_init=nn.initializers.constant(0.1), +# ), +# nn.relu, +# nn.LayerNorm() if self.layernorm else identity, +# ] +# layers += [ +# nn.Dense(1, kernel_init=uniform_init(3e-3), bias_init=uniform_init(3e-3)) +# ] +# network = nn.Sequential(layers) +# state_action = jnp.hstack([state, action]) +# out = network(state_action).squeeze(-1) +# return out +# +# +# class EnsembleCritic(nn.Module): +# hidden_dim: int = 256 +# num_critics: int = 10 +# layernorm: bool = True +# n_hiddens: int = 3 +# +# @nn.compact +# def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: +# ensemble = nn.vmap( +# target=Critic, +# in_axes=None, +# out_axes=0, +# variable_axes={"params": 0}, +# split_rngs={"params": True}, +# axis_size=self.num_critics, +# ) +# q_values = ensemble(self.hidden_dim, self.layernorm, self.n_hiddens)( +# state, action +# ) +# return q_values +# +# +# def qlearning_dataset( +# env: gym.Env, +# dataset: Dict = None, +# terminate_on_end: bool = False, +# **kwargs, +# ) -> Dict: +# if dataset is None: +# dataset = env.get_dataset(**kwargs) +# +# N = dataset["rewards"].shape[0] +# obs_ = [] +# 
next_obs_ = [] +# action_ = [] +# next_action_ = [] +# reward_ = [] +# done_ = [] +# +# # The newer version of the dataset adds an explicit +# # timeouts field. Keep old method for backwards compatability. +# use_timeouts = "timeouts" in dataset +# +# episode_step = 0 +# for i in range(N - 1): +# obs = dataset["observations"][i].astype(np.float32) +# new_obs = dataset["observations"][i + 1].astype(np.float32) +# action = dataset["actions"][i].astype(np.float32) +# new_action = dataset["actions"][i + 1].astype(np.float32) +# reward = dataset["rewards"][i].astype(np.float32) +# done_bool = bool(dataset["terminals"][i]) +# +# if use_timeouts: +# final_timestep = dataset["timeouts"][i] +# else: +# final_timestep = episode_step == env._max_episode_steps - 1 +# if (not terminate_on_end) and final_timestep: +# # Skip this transition +# episode_step = 0 +# continue +# if done_bool or final_timestep: +# episode_step = 0 +# +# obs_.append(obs) +# next_obs_.append(new_obs) +# action_.append(action) +# next_action_.append(new_action) +# reward_.append(reward) +# done_.append(done_bool) +# episode_step += 1 +# +# return { +# "observations": np.array(obs_), +# "actions": np.array(action_), +# "next_observations": np.array(next_obs_), +# "next_actions": np.array(next_action_), +# "rewards": np.array(reward_), +# "terminals": np.array(done_), +# } +# +# +# def compute_mean_std(states: jax.Array, eps: float) -> Tuple[jax.Array, jax.Array]: +# mean = states.mean(0) +# std = states.std(0) + eps +# return mean, std +# +# +# def normalize_states(states: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array: +# return (states - mean) / std +# +# +# @chex.dataclass +# class ReplayBuffer: +# data: Dict[str, jax.Array] = None +# mean: float = 0 +# std: float = 1 +# +# def create_from_d4rl( +# self, +# dataset_name: str, +# normalize_reward: bool = False, +# is_normalize: bool = False, +# ): +# d4rl_data = qlearning_dataset(gym.make(dataset_name)) +# buffer = { +# "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), +# "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), +# "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), +# "next_states": jnp.asarray( +# d4rl_data["next_observations"], dtype=jnp.float32 +# ), +# "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), +# "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), +# } +# if is_normalize: +# self.mean, self.std = compute_mean_std(buffer["states"], eps=1e-3) +# buffer["states"] = normalize_states(buffer["states"], self.mean, self.std) +# buffer["next_states"] = normalize_states( +# buffer["next_states"], self.mean, self.std +# ) +# if normalize_reward: +# buffer["rewards"] = ReplayBuffer.normalize_reward( +# dataset_name, buffer["rewards"] +# ) +# self.data = buffer +# +# @property +# def size(self) -> int: +# # WARN: It will use len of the dataclass, i.e. number of fields. 
+# return self.data["states"].shape[0] +# +# def sample_batch( +# self, key: jax.random.PRNGKey, batch_size: int +# ) -> Dict[str, jax.Array]: +# indices = jax.random.randint( +# key, shape=(batch_size,), minval=0, maxval=self.size +# ) +# batch = jax.tree_map(lambda arr: arr[indices], self.data) +# return batch +# +# def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]: +# mean = self.data[modality].mean(0) +# std = self.data[modality].std(0) +# return mean, std +# +# @staticmethod +# def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array: +# if "antmaze" in dataset_name: +# return rewards * 100.0 # like in LAPO +# else: +# raise NotImplementedError( +# "Reward normalization is implemented only for AntMaze yet!" +# ) +# +# +# class Dataset(object): +# def __init__(self, observations: np.ndarray, actions: np.ndarray, +# rewards: np.ndarray, masks: np.ndarray, +# dones_float: np.ndarray, next_observations: np.ndarray, +# next_actions: np.ndarray, +# mc_returns: np.ndarray, +# size: int): +# self.observations = observations +# self.actions = actions +# self.rewards = rewards +# self.masks = masks +# self.dones_float = dones_float +# self.next_observations = next_observations +# self.next_actions = next_actions +# self.mc_returns = mc_returns +# self.size = size +# +# def sample(self, batch_size: int) -> Dict[str, np.ndarray]: +# indx = np.random.randint(self.size, size=batch_size) +# return { +# "states": self.observations[indx], +# "actions": self.actions[indx], +# "rewards": self.rewards[indx], +# "dones": self.dones_float[indx], +# "next_states": self.next_observations[indx], +# "next_actions": self.next_actions[indx], +# "mc_returns": self.mc_returns[indx], +# } +# +# +# class OnlineReplayBuffer(Dataset): +# def __init__(self, observation_space: gym.spaces.Box, action_dim: int, +# capacity: int): +# +# observations = np.empty((capacity, *observation_space.shape), +# dtype=observation_space.dtype) +# actions = np.empty((capacity, action_dim), dtype=np.float32) +# rewards = np.empty((capacity,), dtype=np.float32) +# mc_returns = np.empty((capacity,), dtype=np.float32) +# masks = np.empty((capacity,), dtype=np.float32) +# dones_float = np.empty((capacity,), dtype=np.float32) +# next_observations = np.empty((capacity, *observation_space.shape), +# dtype=observation_space.dtype) +# next_actions = np.empty((capacity, action_dim), dtype=np.float32) +# super().__init__(observations=observations, +# actions=actions, +# rewards=rewards, +# masks=masks, +# dones_float=dones_float, +# next_observations=next_observations, +# next_actions=next_actions, +# mc_returns=mc_returns, +# size=0) +# +# self.size = 0 +# +# self.insert_index = 0 +# self.capacity = capacity +# +# def initialize_with_dataset(self, dataset: Dataset, +# num_samples=None): +# assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' +# +# dataset_size = len(dataset.observations) +# +# if num_samples is None: +# num_samples = dataset_size +# else: +# num_samples = min(dataset_size, num_samples) +# assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' 
+# +# if num_samples < dataset_size: +# perm = np.random.permutation(dataset_size) +# indices = perm[:num_samples] +# else: +# indices = np.arange(num_samples) +# +# self.observations[:num_samples] = dataset.observations[indices] +# self.actions[:num_samples] = dataset.actions[indices] +# self.rewards[:num_samples] = dataset.rewards[indices] +# self.masks[:num_samples] = dataset.masks[indices] +# self.dones_float[:num_samples] = dataset.dones_float[indices] +# self.next_observations[:num_samples] = dataset.next_observations[ +# indices] +# self.next_actions[:num_samples] = dataset.next_actions[ +# indices] +# self.mc_returns[:num_samples] = dataset.mc_returns[indices] +# +# self.insert_index = num_samples +# self.size = num_samples +# +# def insert(self, observation: np.ndarray, action: np.ndarray, +# reward: float, mask: float, done_float: float, +# next_observation: np.ndarray, +# next_action: np.ndarray, mc_return: np.ndarray): +# self.observations[self.insert_index] = observation +# self.actions[self.insert_index] = action +# self.rewards[self.insert_index] = reward +# self.masks[self.insert_index] = mask +# self.dones_float[self.insert_index] = done_float +# self.next_observations[self.insert_index] = next_observation +# self.next_actions[self.insert_index] = next_action +# self.mc_returns[self.insert_index] = mc_return +# +# self.insert_index = (self.insert_index + 1) % self.capacity +# self.size = min(self.size + 1, self.capacity) +# +# +# class D4RLDataset(Dataset): +# def __init__(self, +# env: gym.Env, +# env_name: str, +# normalize_reward: bool, +# discount: float, +# clip_to_eps: bool = False, +# eps: float = 1e-5): +# d4rl_data = qlearning_dataset(env, env_name, normalize_reward=normalize_reward, discount=discount) +# dataset = { +# "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), +# "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), +# "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), +# "next_states": jnp.asarray(d4rl_data["next_observations"], dtype=jnp.float32), +# "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), +# "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), +# "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) +# } +# +# super().__init__(dataset['states'].astype(np.float32), +# actions=dataset['actions'].astype(np.float32), +# rewards=dataset['rewards'].astype(np.float32), +# masks=1.0 - dataset['dones'].astype(np.float32), +# dones_float=dataset['dones'].astype(np.float32), +# next_observations=dataset['next_states'].astype( +# np.float32), +# next_actions=dataset["next_actions"], +# mc_returns=dataset["mc_returns"], +# size=len(dataset['states'])) +# +# +# @chex.dataclass(frozen=True) +# class Metrics: +# accumulators: Dict[str, Tuple[jax.Array, jax.Array]] +# +# @staticmethod +# def create(metrics: Sequence[str]) -> "Metrics": +# init_metrics = {key: (jnp.array([0.0]), jnp.array([0.0])) for key in metrics} +# return Metrics(accumulators=init_metrics) +# +# def update(self, updates: Dict[str, jax.Array]) -> "Metrics": +# new_accumulators = deepcopy(self.accumulators) +# for key, value in updates.items(): +# acc, steps = new_accumulators[key] +# new_accumulators[key] = (acc + value, steps + 1) +# +# return self.replace(accumulators=new_accumulators) +# +# def compute(self) -> Dict[str, np.ndarray]: +# # cumulative_value / total_steps +# return {k: np.array(v[0] / v[1]) for k, v in self.accumulators.items()} +# +# +# def normalize( +# arr: jax.Array, mean: 
jax.Array, std: jax.Array, eps: float = 1e-8 +# ) -> jax.Array: +# return (arr - mean) / (std + eps) +# +# +# def make_env(env_name: str, seed: int) -> gym.Env: +# env = gym.make(env_name) +# env.seed(seed) +# env.action_space.seed(seed) +# env.observation_space.seed(seed) +# return env +# +# +# def wrap_env( +# env: gym.Env, +# state_mean: Union[np.ndarray, float] = 0.0, +# state_std: Union[np.ndarray, float] = 1.0, +# reward_scale: float = 1.0, +# ) -> gym.Env: +# # PEP 8: E731 do not assign a lambda expression, use a def +# def normalize_state(state: np.ndarray) -> np.ndarray: +# return ( +# state - state_mean +# ) / state_std # epsilon should be already added in std. +# +# def scale_reward(reward: float) -> float: +# # Please be careful, here reward is multiplied by scale! +# return reward_scale * reward +# +# env = gym.wrappers.TransformObservation(env, normalize_state) +# if reward_scale != 1.0: +# env = gym.wrappers.TransformReward(env, scale_reward) +# return env +# +# +# def make_env_and_dataset(env_name: str, +# seed: int, +# normalize_reward: bool, +# discount: float) -> Tuple[gym.Env, D4RLDataset]: +# env = gym.make(env_name) +# +# env.seed(seed) +# env.action_space.seed(seed) +# env.observation_space.seed(seed) +# +# dataset = D4RLDataset(env, env_name, normalize_reward, discount=discount) +# +# return env, dataset +# +# +# def is_goal_reached(reward: float, info: Dict) -> bool: +# if "goal_achieved" in info: +# return info["goal_achieved"] +# return reward > 0 # Assuming that reaching target is a positive reward +# +# +# def evaluate( +# env: gym.Env, +# params: jax.Array, +# action_fn: Callable, +# num_episodes: int, +# seed: int, +# ) -> np.ndarray: +# env.seed(seed) +# env.action_space.seed(seed) +# env.observation_space.seed(seed) +# +# returns = [] +# for _ in trange(num_episodes, desc="Eval", leave=False): +# obs, done = env.reset(), False +# total_reward = 0.0 +# while not done: +# action = np.asarray(jax.device_get(action_fn(params, obs))) +# obs, reward, done, _ = env.step(action) +# total_reward += reward +# returns.append(total_reward) +# +# return np.array(returns) +# +# +# class CriticTrainState(TrainState): +# target_params: FrozenDict +# +# +# class ActorTrainState(TrainState): +# target_params: FrozenDict +# +# +# @jax.jit +# def update_actor( +# key: jax.random.PRNGKey, +# actor: TrainState, +# critic: TrainState, +# batch: Dict[str, jax.Array], +# beta: float, +# tau: float, +# normalize_q: bool, +# metrics: Metrics, +# ) -> Tuple[jax.random.PRNGKey, TrainState, jax.Array, Metrics]: +# key, random_action_key = jax.random.split(key, 2) +# +# def actor_loss_fn(params): +# actions = actor.apply_fn(params, batch["states"]) +# +# bc_penalty = ((actions - batch["actions"]) ** 2).sum(-1) +# q_values = critic.apply_fn(critic.params, batch["states"], actions).min(0) +# # lmbda = 1 +# # # if normalize_q: +# lmbda = jax.lax.stop_gradient(1 / jax.numpy.abs(q_values).mean()) +# +# loss = (beta * bc_penalty - lmbda * q_values).mean() +# +# # logging stuff +# random_actions = jax.random.uniform(random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0) +# new_metrics = metrics.update({ +# "actor_loss": loss, +# "bc_mse_policy": bc_penalty.mean(), +# "bc_mse_random": ((random_actions - batch["actions"]) ** 2).sum(-1).mean(), +# "action_mse": ((actions - batch["actions"]) ** 2).mean() +# }) +# return loss, new_metrics +# +# grads, new_metrics = jax.grad(actor_loss_fn, has_aux=True)(actor.params) +# new_actor = actor.apply_gradients(grads=grads) +# +# new_actor 
= new_actor.replace( +# target_params=optax.incremental_update(actor.params, actor.target_params, tau) +# ) +# new_critic = critic.replace( +# target_params=optax.incremental_update(critic.params, critic.target_params, tau) +# ) +# +# return key, new_actor, new_critic, new_metrics +# +# +# def update_critic( +# key: jax.random.PRNGKey, +# actor: TrainState, +# critic: CriticTrainState, +# batch: Dict[str, jax.Array], +# gamma: float, +# beta: float, +# tau: float, +# policy_noise: float, +# noise_clip: float, +# use_calibration: bool, +# metrics: Metrics, +# ) -> Tuple[jax.random.PRNGKey, TrainState, Metrics]: +# key, actions_key = jax.random.split(key) +# +# next_actions = actor.apply_fn(actor.target_params, batch["next_states"]) +# noise = jax.numpy.clip( +# (jax.random.normal(actions_key, next_actions.shape) * policy_noise), +# -noise_clip, +# noise_clip, +# ) +# next_actions = jax.numpy.clip(next_actions + noise, -1, 1) +# +# bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) +# next_q = critic.apply_fn(critic.target_params, batch["next_states"], next_actions).min(0) +# # lower_bounds = jax.numpy.repeat(batch['mc_returns'].reshape(-1, 1), next_q.shape[1], axis=1) +# next_q = next_q - beta * bc_penalty +# target_q = jax.lax.cond( +# use_calibration, +# lambda: jax.numpy.maximum(batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, batch['mc_returns']), +# lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q +# ) +# +# def critic_loss_fn(critic_params): +# # [N, batch_size] - [1, batch_size] +# q = critic.apply_fn(critic_params, batch["states"], batch["actions"]) +# q_min = q.min(0).mean() +# loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) +# return loss, q_min +# +# (loss, q_min), grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(critic.params) +# new_critic = critic.apply_gradients(grads=grads) +# new_metrics = metrics.update({ +# "critic_loss": loss, +# "q_min": q_min, +# }) +# return key, new_critic, new_metrics +# +# +# @jax.jit +# def update_td3( +# key: jax.random.PRNGKey, +# actor: TrainState, +# critic: CriticTrainState, +# batch: Dict[str, Any], +# metrics: Metrics, +# gamma: float, +# actor_bc_coef: float, +# critic_bc_coef: float, +# tau: float, +# policy_noise: float, +# noise_clip: float, +# normalize_q: bool, +# use_calibration: bool, +# ): +# key, new_critic, new_metrics = update_critic( +# key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics +# ) +# key, new_actor, new_critic, new_metrics = update_actor(key, actor, +# new_critic, batch, actor_bc_coef, tau, normalize_q, +# new_metrics) +# return key, new_actor, new_critic, new_metrics +# +# +# @jax.jit +# def update_td3_no_targets( +# key: jax.random.PRNGKey, +# actor: TrainState, +# critic: CriticTrainState, +# batch: Dict[str, Any], +# gamma: float, +# metrics: Metrics, +# actor_bc_coef: float, +# critic_bc_coef: float, +# tau: float, +# policy_noise: float, +# noise_clip: float, +# use_calibration: bool, +# ): +# key, new_critic, new_metrics = update_critic( +# key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics +# ) +# return key, actor, new_critic, new_metrics +# +# +# def action_fn(actor: TrainState) -> Callable: +# @jax.jit +# def _action_fn(obs: jax.Array) -> jax.Array: +# action = actor.apply_fn(actor.params, obs) +# return action +# +# return _action_fn + + +@pyrallis.wrap() +def main(config: Config): + pyrallis.dump(config, open('run_config.yaml', 'w')) + # # 
config.actor_bc_coef = config.critic_bc_coef * config.bc_coef_mul + # dict_config = asdict(config) + # dict_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") + # is_env_with_goal = config.dataset_name.startswith(ENVS_WITH_GOAL) + # np.random.seed(config.train_seed) + # random.seed(config.train_seed) + # + # wandb.init( + # config=dict_config, + # project=config.project, + # group=config.group, + # name=config.name, + # id=str(uuid.uuid4()), + # ) + # buffer = ReplayBuffer() + # buffer.create_from_d4rl(config.dataset_name, config.normalize_reward, config.normalize_states) + # + # key = jax.random.PRNGKey(seed=config.train_seed) + # key, actor_key, critic_key = jax.random.split(key, 3) + # + # init_state = buffer.data["states"][0][None, ...] + # init_action = buffer.data["actions"][0][None, ...] + # + # actor_module = DetActor(action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, layernorm=config.actor_ln, + # n_hiddens=config.actor_n_hiddens) + # actor = ActorTrainState.create( + # apply_fn=actor_module.apply, + # params=actor_module.init(actor_key, init_state), + # target_params=actor_module.init(actor_key, init_state), + # tx=optax.adam(learning_rate=config.actor_learning_rate), + # ) + # + # critic_module = EnsembleCritic(hidden_dim=config.hidden_dim, num_critics=2, layernorm=config.critic_ln, + # n_hiddens=config.critic_n_hiddens) + # critic = CriticTrainState.create( + # apply_fn=critic_module.apply, + # params=critic_module.init(critic_key, init_state, init_action), + # target_params=critic_module.init(critic_key, init_state, init_action), + # tx=optax.adam(learning_rate=config.critic_learning_rate), + # ) + # + # update_td3_partial = partial( + # update_td3, gamma=config.gamma, + # actor_bc_coef=config.actor_bc_coef, critic_bc_coef=config.critic_bc_coef, tau=config.tau, + # policy_noise=config.policy_noise, + # noise_clip=config.noise_clip, + # normalize_q=config.normalize_q, + # use_calibration=config.use_calibration, + # ) + # + # update_td3_no_targets_partial = partial( + # update_td3_no_targets, gamma=config.gamma, + # actor_bc_coef=config.actor_bc_coef, critic_bc_coef=config.critic_bc_coef, tau=config.tau, + # policy_noise=config.policy_noise, + # noise_clip=config.noise_clip, + # use_calibration=config.use_calibration, + # ) + # + # def td3_loop_update_step(i, carry): + # key, batch_key = jax.random.split(carry["key"]) + # batch = carry["buffer"].sample_batch(batch_key, batch_size=config.batch_size) + # + # full_update = partial(update_td3_partial, + # key=key, + # actor=carry["actor"], + # critic=carry["critic"], + # batch=batch, + # metrics=carry["metrics"]) + # + # update = partial(update_td3_no_targets_partial, + # key=key, + # actor=carry["actor"], + # critic=carry["critic"], + # batch=batch, + # metrics=carry["metrics"]) + # + # key, new_actor, new_critic, new_metrics = jax.lax.cond(carry["delayed_updates"][i], full_update, update) + # + # # key, new_actor, new_critic, new_metrics = update_func( + # # key=key, + # # actor=carry["actor"], + # # critic=carry["critic"], + # # batch=batch, + # # metrics=carry["metrics"] + # # ) + # carry.update( + # key=key, actor=new_actor, critic=new_critic, metrics=new_metrics + # ) + # return carry + # + # # metrics + # bc_metrics_to_log = [ + # "critic_loss", "q_min", "actor_loss", "batch_entropy", + # "bc_mse_policy", "bc_mse_random", "action_mse" + # ] + # # shared carry for update loops + # carry = { + # "key": key, + # "actor": actor, + # "critic": critic, + # "buffer": buffer, + # "delayed_updates": 
jax.numpy.equal( + # jax.numpy.arange(config.num_offline_updates + config.num_online_updates) % config.policy_freq, 0 + # ).astype(int) + # } + # + # # Online + offline tuning + # env, dataset = make_env_and_dataset(config.dataset_name, config.train_seed, False, discount=config.gamma) + # eval_env, _ = make_env_and_dataset(config.dataset_name, config.eval_seed, False, discount=config.gamma) + # + # max_steps = env._max_episode_steps + # + # action_dim = env.action_space.shape[0] + # replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, + # config.replay_buffer_size) + # replay_buffer.initialize_with_dataset(dataset, None) + # online_buffer = OnlineReplayBuffer(env.observation_space, action_dim, config.replay_buffer_size) + # + # online_batch_size = 0 + # offline_batch_size = config.batch_size + # + # observation, done = env.reset(), False + # episode_step = 0 + # goal_achieved = False + # + # @jax.jit + # def actor_action_fn(params, obs): + # return actor.apply_fn(params, obs) + # + # eval_successes = [] + # train_successes = [] + # print("Offline training") + # for i in tqdm.tqdm(range(config.num_online_updates + config.num_offline_updates), smoothing=0.1): + # carry["metrics"] = Metrics.create(bc_metrics_to_log) + # if i == config.num_offline_updates: + # print("Online training") + # + # online_batch_size = int(config.mixing_ratio * config.batch_size) + # offline_batch_size = config.batch_size - online_batch_size + # # Reset optimizers similar to SPOT + # if config.reset_opts: + # actor = actor.replace( + # opt_state=optax.adam(learning_rate=config.actor_learning_rate).init(actor.params) + # ) + # critic = critic.replace( + # opt_state=optax.adam(learning_rate=config.critic_learning_rate).init(critic.params) + # ) + # + # update_td3_partial = partial( + # update_td3, gamma=config.gamma, + # tau=config.tau, + # policy_noise=config.policy_noise, + # noise_clip=config.noise_clip, + # normalize_q=config.normalize_q, + # use_calibration=config.use_calibration, + # ) + # + # update_td3_no_targets_partial = partial( + # update_td3_no_targets, gamma=config.gamma, + # tau=config.tau, + # policy_noise=config.policy_noise, + # noise_clip=config.noise_clip, + # use_calibration=config.use_calibration, + # ) + # online_log = {} + # + # if i >= config.num_offline_updates: + # episode_step += 1 + # action = np.asarray(actor_action_fn(carry["actor"].params, observation)) + # action = np.array( + # [ + # ( + # action + # + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + # ).clip(-1, 1) + # ] + # )[0] + # + # next_observation, reward, done, info = env.step(action) + # if not goal_achieved: + # goal_achieved = is_goal_reached(reward, info) + # next_action = np.asarray(actor_action_fn(carry["actor"].params, next_observation))[0] + # next_action = np.array( + # [ + # ( + # next_action + # + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + # ).clip(-1, 1) + # ] + # )[0] + # + # if not done or 'TimeLimit.truncated' in info: + # mask = 1.0 + # else: + # mask = 0.0 + # real_done = False + # if done and episode_step < max_steps: + # real_done = True + # + # online_buffer.insert(observation, action, reward, mask, + # float(real_done), next_observation, next_action, 0) + # observation = next_observation + # if done: + # train_successes.append(goal_achieved) + # observation, done = env.reset(), False + # episode_step = 0 + # goal_achieved = False + # + # if config.num_offline_updates <= i < config.num_offline_updates + config.num_warmup_steps: + # continue + # + # 
offline_batch = replay_buffer.sample(offline_batch_size) + # online_batch = online_buffer.sample(online_batch_size) + # batch = concat_batches(offline_batch, online_batch) + # + # if 'antmaze' in config.dataset_name and config.normalize_reward: + # batch["rewards"] *= 100 + # + # ### Update step + # actor_bc_coef = config.actor_bc_coef + # critic_bc_coef = config.critic_bc_coef + # if i >= config.num_offline_updates: + # decay_coef = max(config.min_decay_coef, ( + # config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates) + # actor_bc_coef *= decay_coef + # critic_bc_coef *= 0 + # if i % config.policy_freq == 0: + # update_fn = partial(update_td3_partial, + # actor_bc_coef=actor_bc_coef, + # critic_bc_coef=critic_bc_coef, + # key=key, + # actor=carry["actor"], + # critic=carry["critic"], + # batch=batch, + # metrics=carry["metrics"]) + # else: + # update_fn = partial(update_td3_no_targets_partial, + # actor_bc_coef=actor_bc_coef, + # critic_bc_coef=critic_bc_coef, + # key=key, + # actor=carry["actor"], + # critic=carry["critic"], + # batch=batch, + # metrics=carry["metrics"]) + # key, new_actor, new_critic, new_metrics = update_fn() + # carry.update( + # key=key, actor=new_actor, critic=new_critic, metrics=new_metrics + # ) + # + # if i % 1000 == 0: + # mean_metrics = carry["metrics"].compute() + # common = {f"TD3/{k}": v for k, v in mean_metrics.items()} + # common["actor_bc_coef"] = actor_bc_coef + # common["critic_bc_coef"] = critic_bc_coef + # if i < config.num_offline_updates: + # wandb.log({"offline_iter": i, **common}) + # else: + # wandb.log({"online_iter": i - config.num_offline_updates, **common}) + # if i % config.eval_every == 0 or i == config.num_offline_updates + config.num_online_updates - 1 or i == config.num_offline_updates - 1: + # eval_returns, success_rate = evaluate(eval_env, carry["actor"].params, actor_action_fn, + # config.eval_episodes, + # seed=config.eval_seed) + # normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + # eval_successes.append(success_rate) + # if is_env_with_goal: + # online_log["train/regret"] = np.mean(1 - np.array(train_successes)) + # offline_log = { + # "eval/return_mean": np.mean(eval_returns), + # "eval/return_std": np.std(eval_returns), + # "eval/normalized_score_mean": np.mean(normalized_score), + # "eval/normalized_score_std": np.std(normalized_score), + # "eval/success_rate": success_rate + # } + # offline_log.update(online_log) + # wandb.log(offline_log) + + +if __name__ == "__main__": + main() diff --git a/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml new file mode 100644 index 00000000..4da606f8 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.002 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 
0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/large_play_v2.yaml b/configs/finetune/rebrac/antmaze/large_play_v2.yaml new file mode 100644 index 00000000..0210a3ab --- /dev/null +++ b/configs/finetune/rebrac/antmaze/large_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.002 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-large-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-large-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml new file mode 100644 index 00000000..341292eb --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/medium_play_v2.yaml b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml new file mode 100644 index 00000000..4cf200be --- /dev/null +++ b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.001 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.0005 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-medium-play-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-medium-play-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml new file mode 100644 index 00000000..f2639b9c --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.001 +critic_learning_rate: 0.00005 +critic_ln: true 
+critic_n_hiddens: 3 +dataset_name: antmaze-umaze-diverse-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-diverse-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/antmaze/umaze_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_v2.yaml new file mode 100644 index 00000000..e173e309 --- /dev/null +++ b/configs/finetune/rebrac/antmaze/umaze_v2.yaml @@ -0,0 +1,36 @@ +actor_bc_coef: 0.003 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 1.0 +critic_learning_rate: 0.00005 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: antmaze-umaze-v2 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.999 +group: rebrac-finetune-antmaze-umaze-v2 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false + diff --git a/configs/finetune/rebrac/door/cloned_v1.yaml b/configs/finetune/rebrac/door/cloned_v1.yaml new file mode 100644 index 00000000..06db15cd --- /dev/null +++ b/configs/finetune/rebrac/door/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.01 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.1 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: door-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-door-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/hammer/cloned_v1.yaml b/configs/finetune/rebrac/hammer/cloned_v1.yaml new file mode 100644 index 00000000..c55be6f2 --- /dev/null +++ b/configs/finetune/rebrac/hammer/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: hammer-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-hammer-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git 
a/configs/finetune/rebrac/pen/cloned_v1.yaml b/configs/finetune/rebrac/pen/cloned_v1.yaml new file mode 100644 index 00000000..9144ab9c --- /dev/null +++ b/configs/finetune/rebrac/pen/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.05 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.5 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: pen-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-pen-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false diff --git a/configs/finetune/rebrac/relocate/cloned_v1.yaml b/configs/finetune/rebrac/relocate/cloned_v1.yaml new file mode 100644 index 00000000..fefa5048 --- /dev/null +++ b/configs/finetune/rebrac/relocate/cloned_v1.yaml @@ -0,0 +1,35 @@ +actor_bc_coef: 0.1 +actor_learning_rate: 0.0003 +actor_ln: false +actor_n_hiddens: 3 +batch_size: 256 +critic_bc_coef: 0.01 +critic_learning_rate: 0.0003 +critic_ln: true +critic_n_hiddens: 3 +dataset_name: relocate-cloned-v1 +eval_episodes: 100 +eval_every: 50000 +eval_seed: 42 +expl_noise: 0.0 +gamma: 0.99 +group: rebrac-finetune-relocate-cloned-v1 +hidden_dim: 256 +min_decay_coef: 0.5 +mixing_ratio: 0.5 +name: rebrac-finetune +noise_clip: 0.5 +normalize_q: true +normalize_reward: false +normalize_states: false +num_offline_updates: 1000000 +num_online_updates: 1000000 +num_warmup_steps: 0 +policy_freq: 2 +policy_noise: 0.2 +project: CORL +replay_buffer_size: 2000000 +reset_opts: false +tau: 0.005 +train_seed: 0 +use_calibration: false From b969c3fa72a6ff63e1c31838f868e03675cd5fbe Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Thu, 17 Aug 2023 21:18:35 +0200 Subject: [PATCH 02/15] Fix --- algorithms/finetune/rebrac.py | 1837 ++++++++++++++++----------------- 1 file changed, 896 insertions(+), 941 deletions(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index d29822e6..8dce4239 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -10,8 +10,8 @@ import random import chex -# import d4rl # noqa -# import gym +import d4rl # noqa +import gym import jax import numpy as np import optax @@ -81,948 +81,903 @@ def __post_init__(self): self.name = f"{self.name}-{self.dataset_name}-{str(uuid.uuid4())[:8]}" -# def pytorch_init(fan_in: float) -> Callable: -# """ -# Default init for PyTorch Linear layer weights and biases: -# https://pytorch.org/docs/stable/generated/torch.nn.Linear.html -# """ -# bound = math.sqrt(1 / fan_in) -# -# def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: -# return jax.random.uniform( -# key, shape=shape, minval=-bound, maxval=bound, dtype=dtype -# ) -# -# return _init -# -# -# def uniform_init(bound: float) -> Callable: -# def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: -# return jax.random.uniform( -# key, shape=shape, minval=-bound, maxval=bound, dtype=dtype -# ) -# -# return _init -# -# -# def identity(x: Any) -> Any: -# return x -# -# -# class DetActor(nn.Module): -# action_dim: int -# hidden_dim: int = 256 -# layernorm: bool = True -# n_hiddens: int = 3 -# -# @nn.compact -# def 
__call__(self, state: jax.Array) -> jax.Array: -# s_d, h_d = state.shape[-1], self.hidden_dim -# # Initialization as in the EDAC paper -# layers = [ -# nn.Dense( -# self.hidden_dim, -# kernel_init=pytorch_init(s_d), -# bias_init=nn.initializers.constant(0.1), -# ), -# nn.relu, -# nn.LayerNorm() if self.layernorm else identity, -# ] -# for _ in range(self.n_hiddens - 1): -# layers += [ -# nn.Dense( -# self.hidden_dim, -# kernel_init=pytorch_init(h_d), -# bias_init=nn.initializers.constant(0.1), -# ), -# nn.relu, -# nn.LayerNorm() if self.layernorm else identity, -# ] -# layers += [ -# nn.Dense( -# self.action_dim, -# kernel_init=uniform_init(1e-3), -# bias_init=uniform_init(1e-3), -# ), -# nn.tanh, -# ] -# net = nn.Sequential(layers) -# actions = net(state) -# return actions -# -# -# class Critic(nn.Module): -# hidden_dim: int = 256 -# layernorm: bool = True -# n_hiddens: int = 3 -# -# @nn.compact -# def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: -# s_d, a_d, h_d = state.shape[-1], action.shape[-1], self.hidden_dim -# # Initialization as in the EDAC paper -# layers = [ -# nn.Dense( -# self.hidden_dim, -# kernel_init=pytorch_init(s_d + a_d), -# bias_init=nn.initializers.constant(0.1), -# ), -# nn.relu, -# nn.LayerNorm() if self.layernorm else identity, -# ] -# for _ in range(self.n_hiddens - 1): -# layers += [ -# nn.Dense( -# self.hidden_dim, -# kernel_init=pytorch_init(h_d), -# bias_init=nn.initializers.constant(0.1), -# ), -# nn.relu, -# nn.LayerNorm() if self.layernorm else identity, -# ] -# layers += [ -# nn.Dense(1, kernel_init=uniform_init(3e-3), bias_init=uniform_init(3e-3)) -# ] -# network = nn.Sequential(layers) -# state_action = jnp.hstack([state, action]) -# out = network(state_action).squeeze(-1) -# return out -# -# -# class EnsembleCritic(nn.Module): -# hidden_dim: int = 256 -# num_critics: int = 10 -# layernorm: bool = True -# n_hiddens: int = 3 -# -# @nn.compact -# def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: -# ensemble = nn.vmap( -# target=Critic, -# in_axes=None, -# out_axes=0, -# variable_axes={"params": 0}, -# split_rngs={"params": True}, -# axis_size=self.num_critics, -# ) -# q_values = ensemble(self.hidden_dim, self.layernorm, self.n_hiddens)( -# state, action -# ) -# return q_values -# -# -# def qlearning_dataset( -# env: gym.Env, -# dataset: Dict = None, -# terminate_on_end: bool = False, -# **kwargs, -# ) -> Dict: -# if dataset is None: -# dataset = env.get_dataset(**kwargs) -# -# N = dataset["rewards"].shape[0] -# obs_ = [] -# next_obs_ = [] -# action_ = [] -# next_action_ = [] -# reward_ = [] -# done_ = [] -# -# # The newer version of the dataset adds an explicit -# # timeouts field. Keep old method for backwards compatability. 
-# use_timeouts = "timeouts" in dataset -# -# episode_step = 0 -# for i in range(N - 1): -# obs = dataset["observations"][i].astype(np.float32) -# new_obs = dataset["observations"][i + 1].astype(np.float32) -# action = dataset["actions"][i].astype(np.float32) -# new_action = dataset["actions"][i + 1].astype(np.float32) -# reward = dataset["rewards"][i].astype(np.float32) -# done_bool = bool(dataset["terminals"][i]) -# -# if use_timeouts: -# final_timestep = dataset["timeouts"][i] -# else: -# final_timestep = episode_step == env._max_episode_steps - 1 -# if (not terminate_on_end) and final_timestep: -# # Skip this transition -# episode_step = 0 -# continue -# if done_bool or final_timestep: -# episode_step = 0 -# -# obs_.append(obs) -# next_obs_.append(new_obs) -# action_.append(action) -# next_action_.append(new_action) -# reward_.append(reward) -# done_.append(done_bool) -# episode_step += 1 -# -# return { -# "observations": np.array(obs_), -# "actions": np.array(action_), -# "next_observations": np.array(next_obs_), -# "next_actions": np.array(next_action_), -# "rewards": np.array(reward_), -# "terminals": np.array(done_), -# } -# -# -# def compute_mean_std(states: jax.Array, eps: float) -> Tuple[jax.Array, jax.Array]: -# mean = states.mean(0) -# std = states.std(0) + eps -# return mean, std -# -# -# def normalize_states(states: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array: -# return (states - mean) / std -# -# -# @chex.dataclass -# class ReplayBuffer: -# data: Dict[str, jax.Array] = None -# mean: float = 0 -# std: float = 1 -# -# def create_from_d4rl( -# self, -# dataset_name: str, -# normalize_reward: bool = False, -# is_normalize: bool = False, -# ): -# d4rl_data = qlearning_dataset(gym.make(dataset_name)) -# buffer = { -# "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), -# "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), -# "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), -# "next_states": jnp.asarray( -# d4rl_data["next_observations"], dtype=jnp.float32 -# ), -# "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), -# "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), -# } -# if is_normalize: -# self.mean, self.std = compute_mean_std(buffer["states"], eps=1e-3) -# buffer["states"] = normalize_states(buffer["states"], self.mean, self.std) -# buffer["next_states"] = normalize_states( -# buffer["next_states"], self.mean, self.std -# ) -# if normalize_reward: -# buffer["rewards"] = ReplayBuffer.normalize_reward( -# dataset_name, buffer["rewards"] -# ) -# self.data = buffer -# -# @property -# def size(self) -> int: -# # WARN: It will use len of the dataclass, i.e. number of fields. -# return self.data["states"].shape[0] -# -# def sample_batch( -# self, key: jax.random.PRNGKey, batch_size: int -# ) -> Dict[str, jax.Array]: -# indices = jax.random.randint( -# key, shape=(batch_size,), minval=0, maxval=self.size -# ) -# batch = jax.tree_map(lambda arr: arr[indices], self.data) -# return batch -# -# def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]: -# mean = self.data[modality].mean(0) -# std = self.data[modality].std(0) -# return mean, std -# -# @staticmethod -# def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array: -# if "antmaze" in dataset_name: -# return rewards * 100.0 # like in LAPO -# else: -# raise NotImplementedError( -# "Reward normalization is implemented only for AntMaze yet!" 
-# ) -# -# -# class Dataset(object): -# def __init__(self, observations: np.ndarray, actions: np.ndarray, -# rewards: np.ndarray, masks: np.ndarray, -# dones_float: np.ndarray, next_observations: np.ndarray, -# next_actions: np.ndarray, -# mc_returns: np.ndarray, -# size: int): -# self.observations = observations -# self.actions = actions -# self.rewards = rewards -# self.masks = masks -# self.dones_float = dones_float -# self.next_observations = next_observations -# self.next_actions = next_actions -# self.mc_returns = mc_returns -# self.size = size -# -# def sample(self, batch_size: int) -> Dict[str, np.ndarray]: -# indx = np.random.randint(self.size, size=batch_size) -# return { -# "states": self.observations[indx], -# "actions": self.actions[indx], -# "rewards": self.rewards[indx], -# "dones": self.dones_float[indx], -# "next_states": self.next_observations[indx], -# "next_actions": self.next_actions[indx], -# "mc_returns": self.mc_returns[indx], -# } -# -# -# class OnlineReplayBuffer(Dataset): -# def __init__(self, observation_space: gym.spaces.Box, action_dim: int, -# capacity: int): -# -# observations = np.empty((capacity, *observation_space.shape), -# dtype=observation_space.dtype) -# actions = np.empty((capacity, action_dim), dtype=np.float32) -# rewards = np.empty((capacity,), dtype=np.float32) -# mc_returns = np.empty((capacity,), dtype=np.float32) -# masks = np.empty((capacity,), dtype=np.float32) -# dones_float = np.empty((capacity,), dtype=np.float32) -# next_observations = np.empty((capacity, *observation_space.shape), -# dtype=observation_space.dtype) -# next_actions = np.empty((capacity, action_dim), dtype=np.float32) -# super().__init__(observations=observations, -# actions=actions, -# rewards=rewards, -# masks=masks, -# dones_float=dones_float, -# next_observations=next_observations, -# next_actions=next_actions, -# mc_returns=mc_returns, -# size=0) -# -# self.size = 0 -# -# self.insert_index = 0 -# self.capacity = capacity -# -# def initialize_with_dataset(self, dataset: Dataset, -# num_samples=None): -# assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' -# -# dataset_size = len(dataset.observations) -# -# if num_samples is None: -# num_samples = dataset_size -# else: -# num_samples = min(dataset_size, num_samples) -# assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' 
-# -# if num_samples < dataset_size: -# perm = np.random.permutation(dataset_size) -# indices = perm[:num_samples] -# else: -# indices = np.arange(num_samples) -# -# self.observations[:num_samples] = dataset.observations[indices] -# self.actions[:num_samples] = dataset.actions[indices] -# self.rewards[:num_samples] = dataset.rewards[indices] -# self.masks[:num_samples] = dataset.masks[indices] -# self.dones_float[:num_samples] = dataset.dones_float[indices] -# self.next_observations[:num_samples] = dataset.next_observations[ -# indices] -# self.next_actions[:num_samples] = dataset.next_actions[ -# indices] -# self.mc_returns[:num_samples] = dataset.mc_returns[indices] -# -# self.insert_index = num_samples -# self.size = num_samples -# -# def insert(self, observation: np.ndarray, action: np.ndarray, -# reward: float, mask: float, done_float: float, -# next_observation: np.ndarray, -# next_action: np.ndarray, mc_return: np.ndarray): -# self.observations[self.insert_index] = observation -# self.actions[self.insert_index] = action -# self.rewards[self.insert_index] = reward -# self.masks[self.insert_index] = mask -# self.dones_float[self.insert_index] = done_float -# self.next_observations[self.insert_index] = next_observation -# self.next_actions[self.insert_index] = next_action -# self.mc_returns[self.insert_index] = mc_return -# -# self.insert_index = (self.insert_index + 1) % self.capacity -# self.size = min(self.size + 1, self.capacity) -# -# -# class D4RLDataset(Dataset): -# def __init__(self, -# env: gym.Env, -# env_name: str, -# normalize_reward: bool, -# discount: float, -# clip_to_eps: bool = False, -# eps: float = 1e-5): -# d4rl_data = qlearning_dataset(env, env_name, normalize_reward=normalize_reward, discount=discount) -# dataset = { -# "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), -# "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), -# "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), -# "next_states": jnp.asarray(d4rl_data["next_observations"], dtype=jnp.float32), -# "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), -# "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), -# "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) -# } -# -# super().__init__(dataset['states'].astype(np.float32), -# actions=dataset['actions'].astype(np.float32), -# rewards=dataset['rewards'].astype(np.float32), -# masks=1.0 - dataset['dones'].astype(np.float32), -# dones_float=dataset['dones'].astype(np.float32), -# next_observations=dataset['next_states'].astype( -# np.float32), -# next_actions=dataset["next_actions"], -# mc_returns=dataset["mc_returns"], -# size=len(dataset['states'])) -# -# -# @chex.dataclass(frozen=True) -# class Metrics: -# accumulators: Dict[str, Tuple[jax.Array, jax.Array]] -# -# @staticmethod -# def create(metrics: Sequence[str]) -> "Metrics": -# init_metrics = {key: (jnp.array([0.0]), jnp.array([0.0])) for key in metrics} -# return Metrics(accumulators=init_metrics) -# -# def update(self, updates: Dict[str, jax.Array]) -> "Metrics": -# new_accumulators = deepcopy(self.accumulators) -# for key, value in updates.items(): -# acc, steps = new_accumulators[key] -# new_accumulators[key] = (acc + value, steps + 1) -# -# return self.replace(accumulators=new_accumulators) -# -# def compute(self) -> Dict[str, np.ndarray]: -# # cumulative_value / total_steps -# return {k: np.array(v[0] / v[1]) for k, v in self.accumulators.items()} -# -# -# def normalize( -# arr: jax.Array, mean: 
jax.Array, std: jax.Array, eps: float = 1e-8 -# ) -> jax.Array: -# return (arr - mean) / (std + eps) -# -# -# def make_env(env_name: str, seed: int) -> gym.Env: -# env = gym.make(env_name) -# env.seed(seed) -# env.action_space.seed(seed) -# env.observation_space.seed(seed) -# return env -# -# -# def wrap_env( -# env: gym.Env, -# state_mean: Union[np.ndarray, float] = 0.0, -# state_std: Union[np.ndarray, float] = 1.0, -# reward_scale: float = 1.0, -# ) -> gym.Env: -# # PEP 8: E731 do not assign a lambda expression, use a def -# def normalize_state(state: np.ndarray) -> np.ndarray: -# return ( -# state - state_mean -# ) / state_std # epsilon should be already added in std. -# -# def scale_reward(reward: float) -> float: -# # Please be careful, here reward is multiplied by scale! -# return reward_scale * reward -# -# env = gym.wrappers.TransformObservation(env, normalize_state) -# if reward_scale != 1.0: -# env = gym.wrappers.TransformReward(env, scale_reward) -# return env -# -# -# def make_env_and_dataset(env_name: str, -# seed: int, -# normalize_reward: bool, -# discount: float) -> Tuple[gym.Env, D4RLDataset]: -# env = gym.make(env_name) -# -# env.seed(seed) -# env.action_space.seed(seed) -# env.observation_space.seed(seed) -# -# dataset = D4RLDataset(env, env_name, normalize_reward, discount=discount) -# -# return env, dataset -# -# -# def is_goal_reached(reward: float, info: Dict) -> bool: -# if "goal_achieved" in info: -# return info["goal_achieved"] -# return reward > 0 # Assuming that reaching target is a positive reward -# -# -# def evaluate( -# env: gym.Env, -# params: jax.Array, -# action_fn: Callable, -# num_episodes: int, -# seed: int, -# ) -> np.ndarray: -# env.seed(seed) -# env.action_space.seed(seed) -# env.observation_space.seed(seed) -# -# returns = [] -# for _ in trange(num_episodes, desc="Eval", leave=False): -# obs, done = env.reset(), False -# total_reward = 0.0 -# while not done: -# action = np.asarray(jax.device_get(action_fn(params, obs))) -# obs, reward, done, _ = env.step(action) -# total_reward += reward -# returns.append(total_reward) -# -# return np.array(returns) -# -# -# class CriticTrainState(TrainState): -# target_params: FrozenDict -# -# -# class ActorTrainState(TrainState): -# target_params: FrozenDict -# -# -# @jax.jit -# def update_actor( -# key: jax.random.PRNGKey, -# actor: TrainState, -# critic: TrainState, -# batch: Dict[str, jax.Array], -# beta: float, -# tau: float, -# normalize_q: bool, -# metrics: Metrics, -# ) -> Tuple[jax.random.PRNGKey, TrainState, jax.Array, Metrics]: -# key, random_action_key = jax.random.split(key, 2) -# -# def actor_loss_fn(params): -# actions = actor.apply_fn(params, batch["states"]) -# -# bc_penalty = ((actions - batch["actions"]) ** 2).sum(-1) -# q_values = critic.apply_fn(critic.params, batch["states"], actions).min(0) -# # lmbda = 1 -# # # if normalize_q: -# lmbda = jax.lax.stop_gradient(1 / jax.numpy.abs(q_values).mean()) -# -# loss = (beta * bc_penalty - lmbda * q_values).mean() -# -# # logging stuff -# random_actions = jax.random.uniform(random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0) -# new_metrics = metrics.update({ -# "actor_loss": loss, -# "bc_mse_policy": bc_penalty.mean(), -# "bc_mse_random": ((random_actions - batch["actions"]) ** 2).sum(-1).mean(), -# "action_mse": ((actions - batch["actions"]) ** 2).mean() -# }) -# return loss, new_metrics -# -# grads, new_metrics = jax.grad(actor_loss_fn, has_aux=True)(actor.params) -# new_actor = actor.apply_gradients(grads=grads) -# -# new_actor 
= new_actor.replace( -# target_params=optax.incremental_update(actor.params, actor.target_params, tau) -# ) -# new_critic = critic.replace( -# target_params=optax.incremental_update(critic.params, critic.target_params, tau) -# ) -# -# return key, new_actor, new_critic, new_metrics -# -# -# def update_critic( -# key: jax.random.PRNGKey, -# actor: TrainState, -# critic: CriticTrainState, -# batch: Dict[str, jax.Array], -# gamma: float, -# beta: float, -# tau: float, -# policy_noise: float, -# noise_clip: float, -# use_calibration: bool, -# metrics: Metrics, -# ) -> Tuple[jax.random.PRNGKey, TrainState, Metrics]: -# key, actions_key = jax.random.split(key) -# -# next_actions = actor.apply_fn(actor.target_params, batch["next_states"]) -# noise = jax.numpy.clip( -# (jax.random.normal(actions_key, next_actions.shape) * policy_noise), -# -noise_clip, -# noise_clip, -# ) -# next_actions = jax.numpy.clip(next_actions + noise, -1, 1) -# -# bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) -# next_q = critic.apply_fn(critic.target_params, batch["next_states"], next_actions).min(0) -# # lower_bounds = jax.numpy.repeat(batch['mc_returns'].reshape(-1, 1), next_q.shape[1], axis=1) -# next_q = next_q - beta * bc_penalty -# target_q = jax.lax.cond( -# use_calibration, -# lambda: jax.numpy.maximum(batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, batch['mc_returns']), -# lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q -# ) -# -# def critic_loss_fn(critic_params): -# # [N, batch_size] - [1, batch_size] -# q = critic.apply_fn(critic_params, batch["states"], batch["actions"]) -# q_min = q.min(0).mean() -# loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) -# return loss, q_min -# -# (loss, q_min), grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(critic.params) -# new_critic = critic.apply_gradients(grads=grads) -# new_metrics = metrics.update({ -# "critic_loss": loss, -# "q_min": q_min, -# }) -# return key, new_critic, new_metrics -# -# -# @jax.jit -# def update_td3( -# key: jax.random.PRNGKey, -# actor: TrainState, -# critic: CriticTrainState, -# batch: Dict[str, Any], -# metrics: Metrics, -# gamma: float, -# actor_bc_coef: float, -# critic_bc_coef: float, -# tau: float, -# policy_noise: float, -# noise_clip: float, -# normalize_q: bool, -# use_calibration: bool, -# ): -# key, new_critic, new_metrics = update_critic( -# key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics -# ) -# key, new_actor, new_critic, new_metrics = update_actor(key, actor, -# new_critic, batch, actor_bc_coef, tau, normalize_q, -# new_metrics) -# return key, new_actor, new_critic, new_metrics -# -# -# @jax.jit -# def update_td3_no_targets( -# key: jax.random.PRNGKey, -# actor: TrainState, -# critic: CriticTrainState, -# batch: Dict[str, Any], -# gamma: float, -# metrics: Metrics, -# actor_bc_coef: float, -# critic_bc_coef: float, -# tau: float, -# policy_noise: float, -# noise_clip: float, -# use_calibration: bool, -# ): -# key, new_critic, new_metrics = update_critic( -# key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics -# ) -# return key, actor, new_critic, new_metrics -# -# -# def action_fn(actor: TrainState) -> Callable: -# @jax.jit -# def _action_fn(obs: jax.Array) -> jax.Array: -# action = actor.apply_fn(actor.params, obs) -# return action -# -# return _action_fn +def pytorch_init(fan_in: float) -> Callable: + """ + Default init for PyTorch Linear layer weights and 
biases: + https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + """ + bound = math.sqrt(1 / fan_in) + + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def uniform_init(bound: float) -> Callable: + def _init(key: jax.random.PRNGKey, shape: Tuple, dtype: type) -> jax.Array: + return jax.random.uniform( + key, shape=shape, minval=-bound, maxval=bound, dtype=dtype + ) + + return _init + + +def identity(x: Any) -> Any: + return x + + +class DetActor(nn.Module): + action_dim: int + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array) -> jax.Array: + s_d, h_d = state.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense( + self.action_dim, + kernel_init=uniform_init(1e-3), + bias_init=uniform_init(1e-3), + ), + nn.tanh, + ] + net = nn.Sequential(layers) + actions = net(state) + return actions + + +class Critic(nn.Module): + hidden_dim: int = 256 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + s_d, a_d, h_d = state.shape[-1], action.shape[-1], self.hidden_dim + # Initialization as in the EDAC paper + layers = [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(s_d + a_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + for _ in range(self.n_hiddens - 1): + layers += [ + nn.Dense( + self.hidden_dim, + kernel_init=pytorch_init(h_d), + bias_init=nn.initializers.constant(0.1), + ), + nn.relu, + nn.LayerNorm() if self.layernorm else identity, + ] + layers += [ + nn.Dense(1, kernel_init=uniform_init(3e-3), bias_init=uniform_init(3e-3)) + ] + network = nn.Sequential(layers) + state_action = jnp.hstack([state, action]) + out = network(state_action).squeeze(-1) + return out + + +class EnsembleCritic(nn.Module): + hidden_dim: int = 256 + num_critics: int = 10 + layernorm: bool = True + n_hiddens: int = 3 + + @nn.compact + def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: + ensemble = nn.vmap( + target=Critic, + in_axes=None, + out_axes=0, + variable_axes={"params": 0}, + split_rngs={"params": True}, + axis_size=self.num_critics, + ) + q_values = ensemble(self.hidden_dim, self.layernorm, self.n_hiddens)( + state, action + ) + return q_values + + +def qlearning_dataset( + env: gym.Env, + dataset: Dict = None, + terminate_on_end: bool = False, + **kwargs, +) -> Dict: + if dataset is None: + dataset = env.get_dataset(**kwargs) + + N = dataset["rewards"].shape[0] + obs_ = [] + next_obs_ = [] + action_ = [] + next_action_ = [] + reward_ = [] + done_ = [] + + # The newer version of the dataset adds an explicit + # timeouts field. Keep old method for backwards compatability. 
+ use_timeouts = "timeouts" in dataset + + episode_step = 0 + for i in range(N - 1): + obs = dataset["observations"][i].astype(np.float32) + new_obs = dataset["observations"][i + 1].astype(np.float32) + action = dataset["actions"][i].astype(np.float32) + new_action = dataset["actions"][i + 1].astype(np.float32) + reward = dataset["rewards"][i].astype(np.float32) + done_bool = bool(dataset["terminals"][i]) + + if use_timeouts: + final_timestep = dataset["timeouts"][i] + else: + final_timestep = episode_step == env._max_episode_steps - 1 + if (not terminate_on_end) and final_timestep: + # Skip this transition + episode_step = 0 + continue + if done_bool or final_timestep: + episode_step = 0 + + obs_.append(obs) + next_obs_.append(new_obs) + action_.append(action) + next_action_.append(new_action) + reward_.append(reward) + done_.append(done_bool) + episode_step += 1 + + return { + "observations": np.array(obs_), + "actions": np.array(action_), + "next_observations": np.array(next_obs_), + "next_actions": np.array(next_action_), + "rewards": np.array(reward_), + "terminals": np.array(done_), + } + + +def compute_mean_std(states: jax.Array, eps: float) -> Tuple[jax.Array, jax.Array]: + mean = states.mean(0) + std = states.std(0) + eps + return mean, std + + +def normalize_states(states: jax.Array, mean: jax.Array, std: jax.Array) -> jax.Array: + return (states - mean) / std + + +@chex.dataclass +class ReplayBuffer: + data: Dict[str, jax.Array] = None + mean: float = 0 + std: float = 1 + + def create_from_d4rl( + self, + dataset_name: str, + normalize_reward: bool = False, + is_normalize: bool = False, + ): + d4rl_data = qlearning_dataset(gym.make(dataset_name)) + buffer = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + } + if is_normalize: + self.mean, self.std = compute_mean_std(buffer["states"], eps=1e-3) + buffer["states"] = normalize_states(buffer["states"], self.mean, self.std) + buffer["next_states"] = normalize_states( + buffer["next_states"], self.mean, self.std + ) + if normalize_reward: + buffer["rewards"] = ReplayBuffer.normalize_reward( + dataset_name, buffer["rewards"] + ) + self.data = buffer + + @property + def size(self) -> int: + # WARN: It will use len of the dataclass, i.e. number of fields. + return self.data["states"].shape[0] + + def sample_batch( + self, key: jax.random.PRNGKey, batch_size: int + ) -> Dict[str, jax.Array]: + indices = jax.random.randint( + key, shape=(batch_size,), minval=0, maxval=self.size + ) + batch = jax.tree_map(lambda arr: arr[indices], self.data) + return batch + + def get_moments(self, modality: str) -> Tuple[jax.Array, jax.Array]: + mean = self.data[modality].mean(0) + std = self.data[modality].std(0) + return mean, std + + @staticmethod + def normalize_reward(dataset_name: str, rewards: jax.Array) -> jax.Array: + if "antmaze" in dataset_name: + return rewards * 100.0 # like in LAPO + else: + raise NotImplementedError( + "Reward normalization is implemented only for AntMaze yet!" 
+ ) + + +class Dataset(object): + def __init__(self, observations: np.ndarray, actions: np.ndarray, + rewards: np.ndarray, masks: np.ndarray, + dones_float: np.ndarray, next_observations: np.ndarray, + next_actions: np.ndarray, + mc_returns: np.ndarray, + size: int): + self.observations = observations + self.actions = actions + self.rewards = rewards + self.masks = masks + self.dones_float = dones_float + self.next_observations = next_observations + self.next_actions = next_actions + self.mc_returns = mc_returns + self.size = size + + def sample(self, batch_size: int) -> Dict[str, np.ndarray]: + indx = np.random.randint(self.size, size=batch_size) + return { + "states": self.observations[indx], + "actions": self.actions[indx], + "rewards": self.rewards[indx], + "dones": self.dones_float[indx], + "next_states": self.next_observations[indx], + "next_actions": self.next_actions[indx], + "mc_returns": self.mc_returns[indx], + } + + +class OnlineReplayBuffer(Dataset): + def __init__(self, observation_space: gym.spaces.Box, action_dim: int, + capacity: int): + + observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + actions = np.empty((capacity, action_dim), dtype=np.float32) + rewards = np.empty((capacity,), dtype=np.float32) + mc_returns = np.empty((capacity,), dtype=np.float32) + masks = np.empty((capacity,), dtype=np.float32) + dones_float = np.empty((capacity,), dtype=np.float32) + next_observations = np.empty((capacity, *observation_space.shape), + dtype=observation_space.dtype) + next_actions = np.empty((capacity, action_dim), dtype=np.float32) + super().__init__(observations=observations, + actions=actions, + rewards=rewards, + masks=masks, + dones_float=dones_float, + next_observations=next_observations, + next_actions=next_actions, + mc_returns=mc_returns, + size=0) + + self.size = 0 + + self.insert_index = 0 + self.capacity = capacity + + def initialize_with_dataset(self, dataset: Dataset, + num_samples=None): + assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' + + dataset_size = len(dataset.observations) + + if num_samples is None: + num_samples = dataset_size + else: + num_samples = min(dataset_size, num_samples) + assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' 
+ + if num_samples < dataset_size: + perm = np.random.permutation(dataset_size) + indices = perm[:num_samples] + else: + indices = np.arange(num_samples) + + self.observations[:num_samples] = dataset.observations[indices] + self.actions[:num_samples] = dataset.actions[indices] + self.rewards[:num_samples] = dataset.rewards[indices] + self.masks[:num_samples] = dataset.masks[indices] + self.dones_float[:num_samples] = dataset.dones_float[indices] + self.next_observations[:num_samples] = dataset.next_observations[ + indices] + self.next_actions[:num_samples] = dataset.next_actions[ + indices] + self.mc_returns[:num_samples] = dataset.mc_returns[indices] + + self.insert_index = num_samples + self.size = num_samples + + def insert(self, observation: np.ndarray, action: np.ndarray, + reward: float, mask: float, done_float: float, + next_observation: np.ndarray, + next_action: np.ndarray, mc_return: np.ndarray): + self.observations[self.insert_index] = observation + self.actions[self.insert_index] = action + self.rewards[self.insert_index] = reward + self.masks[self.insert_index] = mask + self.dones_float[self.insert_index] = done_float + self.next_observations[self.insert_index] = next_observation + self.next_actions[self.insert_index] = next_action + self.mc_returns[self.insert_index] = mc_return + + self.insert_index = (self.insert_index + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + +class D4RLDataset(Dataset): + def __init__(self, + env: gym.Env, + env_name: str, + normalize_reward: bool, + discount: float, + clip_to_eps: bool = False, + eps: float = 1e-5): + d4rl_data = qlearning_dataset(env, env_name, normalize_reward=normalize_reward, discount=discount) + dataset = { + "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), + "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), + "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), + "next_states": jnp.asarray(d4rl_data["next_observations"], dtype=jnp.float32), + "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), + "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), + "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) + } + + super().__init__(dataset['states'].astype(np.float32), + actions=dataset['actions'].astype(np.float32), + rewards=dataset['rewards'].astype(np.float32), + masks=1.0 - dataset['dones'].astype(np.float32), + dones_float=dataset['dones'].astype(np.float32), + next_observations=dataset['next_states'].astype( + np.float32), + next_actions=dataset["next_actions"], + mc_returns=dataset["mc_returns"], + size=len(dataset['states'])) + + +def concat_batches(b1, b2): + new_batch = {} + for k in b1: + new_batch[k] = np.concatenate((b1[k], b2[k]), axis=0) + return new_batch + + +@chex.dataclass(frozen=True) +class Metrics: + accumulators: Dict[str, Tuple[jax.Array, jax.Array]] + + @staticmethod + def create(metrics: Sequence[str]) -> "Metrics": + init_metrics = {key: (jnp.array([0.0]), jnp.array([0.0])) for key in metrics} + return Metrics(accumulators=init_metrics) + + def update(self, updates: Dict[str, jax.Array]) -> "Metrics": + new_accumulators = deepcopy(self.accumulators) + for key, value in updates.items(): + acc, steps = new_accumulators[key] + new_accumulators[key] = (acc + value, steps + 1) + + return self.replace(accumulators=new_accumulators) + + def compute(self) -> Dict[str, np.ndarray]: + # cumulative_value / total_steps + return {k: np.array(v[0] / v[1]) for k, v in self.accumulators.items()} + 
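The Metrics container above keeps a running (sum, count) pair per key, so the jitted update functions can accumulate values and the host only divides once per logging interval. A minimal usage sketch, assuming the Metrics class defined above; the metric names and values here are hypothetical:

metrics = Metrics.create(["critic_loss", "q_min"])
metrics = metrics.update({"critic_loss": jnp.array(1.5), "q_min": jnp.array(0.2)})
metrics = metrics.update({"critic_loss": jnp.array(0.5), "q_min": jnp.array(0.4)})
print(metrics.compute())  # per-key running means: critic_loss ~1.0, q_min ~0.3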
+ +def normalize( + arr: jax.Array, mean: jax.Array, std: jax.Array, eps: float = 1e-8 +) -> jax.Array: + return (arr - mean) / (std + eps) + + +def make_env(env_name: str, seed: int) -> gym.Env: + env = gym.make(env_name) + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + +def wrap_env( + env: gym.Env, + state_mean: Union[np.ndarray, float] = 0.0, + state_std: Union[np.ndarray, float] = 1.0, + reward_scale: float = 1.0, +) -> gym.Env: + # PEP 8: E731 do not assign a lambda expression, use a def + def normalize_state(state: np.ndarray) -> np.ndarray: + return ( + state - state_mean + ) / state_std # epsilon should be already added in std. + + def scale_reward(reward: float) -> float: + # Please be careful, here reward is multiplied by scale! + return reward_scale * reward + + env = gym.wrappers.TransformObservation(env, normalize_state) + if reward_scale != 1.0: + env = gym.wrappers.TransformReward(env, scale_reward) + return env + + +def make_env_and_dataset(env_name: str, + seed: int, + normalize_reward: bool, + discount: float) -> Tuple[gym.Env, D4RLDataset]: + env = gym.make(env_name) + + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + dataset = D4RLDataset(env, env_name, normalize_reward, discount=discount) + + return env, dataset + + +def is_goal_reached(reward: float, info: Dict) -> bool: + if "goal_achieved" in info: + return info["goal_achieved"] + return reward > 0 # Assuming that reaching target is a positive reward + + +def evaluate( + env: gym.Env, + params: jax.Array, + action_fn: Callable, + num_episodes: int, + seed: int, +) -> np.ndarray: + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + + returns = [] + for _ in trange(num_episodes, desc="Eval", leave=False): + obs, done = env.reset(), False + total_reward = 0.0 + while not done: + action = np.asarray(jax.device_get(action_fn(params, obs))) + obs, reward, done, _ = env.step(action) + total_reward += reward + returns.append(total_reward) + + return np.array(returns) + + +class CriticTrainState(TrainState): + target_params: FrozenDict + + +class ActorTrainState(TrainState): + target_params: FrozenDict + + +@jax.jit +def update_actor( + key: jax.random.PRNGKey, + actor: TrainState, + critic: TrainState, + batch: Dict[str, jax.Array], + beta: float, + tau: float, + normalize_q: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, jax.Array, Metrics]: + key, random_action_key = jax.random.split(key, 2) + + def actor_loss_fn(params): + actions = actor.apply_fn(params, batch["states"]) + + bc_penalty = ((actions - batch["actions"]) ** 2).sum(-1) + q_values = critic.apply_fn(critic.params, batch["states"], actions).min(0) + # lmbda = 1 + # # if normalize_q: + lmbda = jax.lax.stop_gradient(1 / jax.numpy.abs(q_values).mean()) + + loss = (beta * bc_penalty - lmbda * q_values).mean() + + # logging stuff + random_actions = jax.random.uniform(random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0) + new_metrics = metrics.update({ + "actor_loss": loss, + "bc_mse_policy": bc_penalty.mean(), + "bc_mse_random": ((random_actions - batch["actions"]) ** 2).sum(-1).mean(), + "action_mse": ((actions - batch["actions"]) ** 2).mean() + }) + return loss, new_metrics + + grads, new_metrics = jax.grad(actor_loss_fn, has_aux=True)(actor.params) + new_actor = actor.apply_gradients(grads=grads) + + new_actor = new_actor.replace( + target_params=optax.incremental_update(actor.params, 
actor.target_params, tau) + ) + new_critic = critic.replace( + target_params=optax.incremental_update(critic.params, critic.target_params, tau) + ) + + return key, new_actor, new_critic, new_metrics + + +def update_critic( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, jax.Array], + gamma: float, + beta: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, + metrics: Metrics, +) -> Tuple[jax.random.PRNGKey, TrainState, Metrics]: + key, actions_key = jax.random.split(key) + + next_actions = actor.apply_fn(actor.target_params, batch["next_states"]) + noise = jax.numpy.clip( + (jax.random.normal(actions_key, next_actions.shape) * policy_noise), + -noise_clip, + noise_clip, + ) + next_actions = jax.numpy.clip(next_actions + noise, -1, 1) + + bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) + next_q = critic.apply_fn(critic.target_params, batch["next_states"], next_actions).min(0) + next_q = next_q - beta * bc_penalty + target_q = jax.lax.cond( + use_calibration, + lambda: jax.numpy.maximum(batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, batch['mc_returns']), + lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q + ) + + def critic_loss_fn(critic_params): + # [N, batch_size] - [1, batch_size] + q = critic.apply_fn(critic_params, batch["states"], batch["actions"]) + q_min = q.min(0).mean() + loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) + return loss, q_min + + (loss, q_min), grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(critic.params) + new_critic = critic.apply_gradients(grads=grads) + new_metrics = metrics.update({ + "critic_loss": loss, + "q_min": q_min, + }) + return key, new_critic, new_metrics + + +@jax.jit +def update_td3( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + metrics: Metrics, + gamma: float, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + normalize_q: bool, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics + ) + key, new_actor, new_critic, new_metrics = update_actor(key, actor, + new_critic, batch, actor_bc_coef, tau, normalize_q, + new_metrics) + return key, new_actor, new_critic, new_metrics + + +@jax.jit +def update_td3_no_targets( + key: jax.random.PRNGKey, + actor: TrainState, + critic: CriticTrainState, + batch: Dict[str, Any], + gamma: float, + metrics: Metrics, + actor_bc_coef: float, + critic_bc_coef: float, + tau: float, + policy_noise: float, + noise_clip: float, + use_calibration: bool, +): + key, new_critic, new_metrics = update_critic( + key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics + ) + return key, actor, new_critic, new_metrics + + +def action_fn(actor: TrainState) -> Callable: + @jax.jit + def _action_fn(obs: jax.Array) -> jax.Array: + action = actor.apply_fn(actor.params, obs) + return action + + return _action_fn @pyrallis.wrap() -def main(config: Config): - pyrallis.dump(config, open('run_config.yaml', 'w')) - # # config.actor_bc_coef = config.critic_bc_coef * config.bc_coef_mul - # dict_config = asdict(config) - # dict_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") - # is_env_with_goal = config.dataset_name.startswith(ENVS_WITH_GOAL) - # np.random.seed(config.train_seed) - # 
random.seed(config.train_seed) - # - # wandb.init( - # config=dict_config, - # project=config.project, - # group=config.group, - # name=config.name, - # id=str(uuid.uuid4()), - # ) - # buffer = ReplayBuffer() - # buffer.create_from_d4rl(config.dataset_name, config.normalize_reward, config.normalize_states) - # - # key = jax.random.PRNGKey(seed=config.train_seed) - # key, actor_key, critic_key = jax.random.split(key, 3) - # - # init_state = buffer.data["states"][0][None, ...] - # init_action = buffer.data["actions"][0][None, ...] - # - # actor_module = DetActor(action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, layernorm=config.actor_ln, - # n_hiddens=config.actor_n_hiddens) - # actor = ActorTrainState.create( - # apply_fn=actor_module.apply, - # params=actor_module.init(actor_key, init_state), - # target_params=actor_module.init(actor_key, init_state), - # tx=optax.adam(learning_rate=config.actor_learning_rate), - # ) - # - # critic_module = EnsembleCritic(hidden_dim=config.hidden_dim, num_critics=2, layernorm=config.critic_ln, - # n_hiddens=config.critic_n_hiddens) - # critic = CriticTrainState.create( - # apply_fn=critic_module.apply, - # params=critic_module.init(critic_key, init_state, init_action), - # target_params=critic_module.init(critic_key, init_state, init_action), - # tx=optax.adam(learning_rate=config.critic_learning_rate), - # ) - # - # update_td3_partial = partial( - # update_td3, gamma=config.gamma, - # actor_bc_coef=config.actor_bc_coef, critic_bc_coef=config.critic_bc_coef, tau=config.tau, - # policy_noise=config.policy_noise, - # noise_clip=config.noise_clip, - # normalize_q=config.normalize_q, - # use_calibration=config.use_calibration, - # ) - # - # update_td3_no_targets_partial = partial( - # update_td3_no_targets, gamma=config.gamma, - # actor_bc_coef=config.actor_bc_coef, critic_bc_coef=config.critic_bc_coef, tau=config.tau, - # policy_noise=config.policy_noise, - # noise_clip=config.noise_clip, - # use_calibration=config.use_calibration, - # ) - # - # def td3_loop_update_step(i, carry): - # key, batch_key = jax.random.split(carry["key"]) - # batch = carry["buffer"].sample_batch(batch_key, batch_size=config.batch_size) - # - # full_update = partial(update_td3_partial, - # key=key, - # actor=carry["actor"], - # critic=carry["critic"], - # batch=batch, - # metrics=carry["metrics"]) - # - # update = partial(update_td3_no_targets_partial, - # key=key, - # actor=carry["actor"], - # critic=carry["critic"], - # batch=batch, - # metrics=carry["metrics"]) - # - # key, new_actor, new_critic, new_metrics = jax.lax.cond(carry["delayed_updates"][i], full_update, update) - # - # # key, new_actor, new_critic, new_metrics = update_func( - # # key=key, - # # actor=carry["actor"], - # # critic=carry["critic"], - # # batch=batch, - # # metrics=carry["metrics"] - # # ) - # carry.update( - # key=key, actor=new_actor, critic=new_critic, metrics=new_metrics - # ) - # return carry - # - # # metrics - # bc_metrics_to_log = [ - # "critic_loss", "q_min", "actor_loss", "batch_entropy", - # "bc_mse_policy", "bc_mse_random", "action_mse" - # ] - # # shared carry for update loops - # carry = { - # "key": key, - # "actor": actor, - # "critic": critic, - # "buffer": buffer, - # "delayed_updates": jax.numpy.equal( - # jax.numpy.arange(config.num_offline_updates + config.num_online_updates) % config.policy_freq, 0 - # ).astype(int) - # } - # - # # Online + offline tuning - # env, dataset = make_env_and_dataset(config.dataset_name, config.train_seed, False, discount=config.gamma) - # 
eval_env, _ = make_env_and_dataset(config.dataset_name, config.eval_seed, False, discount=config.gamma) - # - # max_steps = env._max_episode_steps - # - # action_dim = env.action_space.shape[0] - # replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, - # config.replay_buffer_size) - # replay_buffer.initialize_with_dataset(dataset, None) - # online_buffer = OnlineReplayBuffer(env.observation_space, action_dim, config.replay_buffer_size) - # - # online_batch_size = 0 - # offline_batch_size = config.batch_size - # - # observation, done = env.reset(), False - # episode_step = 0 - # goal_achieved = False - # - # @jax.jit - # def actor_action_fn(params, obs): - # return actor.apply_fn(params, obs) - # - # eval_successes = [] - # train_successes = [] - # print("Offline training") - # for i in tqdm.tqdm(range(config.num_online_updates + config.num_offline_updates), smoothing=0.1): - # carry["metrics"] = Metrics.create(bc_metrics_to_log) - # if i == config.num_offline_updates: - # print("Online training") - # - # online_batch_size = int(config.mixing_ratio * config.batch_size) - # offline_batch_size = config.batch_size - online_batch_size - # # Reset optimizers similar to SPOT - # if config.reset_opts: - # actor = actor.replace( - # opt_state=optax.adam(learning_rate=config.actor_learning_rate).init(actor.params) - # ) - # critic = critic.replace( - # opt_state=optax.adam(learning_rate=config.critic_learning_rate).init(critic.params) - # ) - # - # update_td3_partial = partial( - # update_td3, gamma=config.gamma, - # tau=config.tau, - # policy_noise=config.policy_noise, - # noise_clip=config.noise_clip, - # normalize_q=config.normalize_q, - # use_calibration=config.use_calibration, - # ) - # - # update_td3_no_targets_partial = partial( - # update_td3_no_targets, gamma=config.gamma, - # tau=config.tau, - # policy_noise=config.policy_noise, - # noise_clip=config.noise_clip, - # use_calibration=config.use_calibration, - # ) - # online_log = {} - # - # if i >= config.num_offline_updates: - # episode_step += 1 - # action = np.asarray(actor_action_fn(carry["actor"].params, observation)) - # action = np.array( - # [ - # ( - # action - # + np.random.normal(0, 1 * config.expl_noise, size=action_dim) - # ).clip(-1, 1) - # ] - # )[0] - # - # next_observation, reward, done, info = env.step(action) - # if not goal_achieved: - # goal_achieved = is_goal_reached(reward, info) - # next_action = np.asarray(actor_action_fn(carry["actor"].params, next_observation))[0] - # next_action = np.array( - # [ - # ( - # next_action - # + np.random.normal(0, 1 * config.expl_noise, size=action_dim) - # ).clip(-1, 1) - # ] - # )[0] - # - # if not done or 'TimeLimit.truncated' in info: - # mask = 1.0 - # else: - # mask = 0.0 - # real_done = False - # if done and episode_step < max_steps: - # real_done = True - # - # online_buffer.insert(observation, action, reward, mask, - # float(real_done), next_observation, next_action, 0) - # observation = next_observation - # if done: - # train_successes.append(goal_achieved) - # observation, done = env.reset(), False - # episode_step = 0 - # goal_achieved = False - # - # if config.num_offline_updates <= i < config.num_offline_updates + config.num_warmup_steps: - # continue - # - # offline_batch = replay_buffer.sample(offline_batch_size) - # online_batch = online_buffer.sample(online_batch_size) - # batch = concat_batches(offline_batch, online_batch) - # - # if 'antmaze' in config.dataset_name and config.normalize_reward: - # batch["rewards"] *= 100 - # - # ### Update step - # 
actor_bc_coef = config.actor_bc_coef - # critic_bc_coef = config.critic_bc_coef - # if i >= config.num_offline_updates: - # decay_coef = max(config.min_decay_coef, ( - # config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates) - # actor_bc_coef *= decay_coef - # critic_bc_coef *= 0 - # if i % config.policy_freq == 0: - # update_fn = partial(update_td3_partial, - # actor_bc_coef=actor_bc_coef, - # critic_bc_coef=critic_bc_coef, - # key=key, - # actor=carry["actor"], - # critic=carry["critic"], - # batch=batch, - # metrics=carry["metrics"]) - # else: - # update_fn = partial(update_td3_no_targets_partial, - # actor_bc_coef=actor_bc_coef, - # critic_bc_coef=critic_bc_coef, - # key=key, - # actor=carry["actor"], - # critic=carry["critic"], - # batch=batch, - # metrics=carry["metrics"]) - # key, new_actor, new_critic, new_metrics = update_fn() - # carry.update( - # key=key, actor=new_actor, critic=new_critic, metrics=new_metrics - # ) - # - # if i % 1000 == 0: - # mean_metrics = carry["metrics"].compute() - # common = {f"TD3/{k}": v for k, v in mean_metrics.items()} - # common["actor_bc_coef"] = actor_bc_coef - # common["critic_bc_coef"] = critic_bc_coef - # if i < config.num_offline_updates: - # wandb.log({"offline_iter": i, **common}) - # else: - # wandb.log({"online_iter": i - config.num_offline_updates, **common}) - # if i % config.eval_every == 0 or i == config.num_offline_updates + config.num_online_updates - 1 or i == config.num_offline_updates - 1: - # eval_returns, success_rate = evaluate(eval_env, carry["actor"].params, actor_action_fn, - # config.eval_episodes, - # seed=config.eval_seed) - # normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 - # eval_successes.append(success_rate) - # if is_env_with_goal: - # online_log["train/regret"] = np.mean(1 - np.array(train_successes)) - # offline_log = { - # "eval/return_mean": np.mean(eval_returns), - # "eval/return_std": np.std(eval_returns), - # "eval/normalized_score_mean": np.mean(normalized_score), - # "eval/normalized_score_std": np.std(normalized_score), - # "eval/success_rate": success_rate - # } - # offline_log.update(online_log) - # wandb.log(offline_log) +def train(config: Config): + dict_config = asdict(config) + dict_config["mlc_job_name"] = os.environ.get("PLATFORM_JOB_NAME") + is_env_with_goal = config.dataset_name.startswith(ENVS_WITH_GOAL) + np.random.seed(config.train_seed) + random.seed(config.train_seed) + + wandb.init( + config=dict_config, + project=config.project, + group=config.group, + name=config.name, + id=str(uuid.uuid4()), + ) + buffer = ReplayBuffer() + buffer.create_from_d4rl(config.dataset_name, config.normalize_reward, config.normalize_states) + + key = jax.random.PRNGKey(seed=config.train_seed) + key, actor_key, critic_key = jax.random.split(key, 3) + + init_state = buffer.data["states"][0][None, ...] + init_action = buffer.data["actions"][0][None, ...] 
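Both train states created below carry target_params alongside the online params; optax.incremental_update then performs the soft (Polyak) target update used by the TD3-style updates above, i.e. target <- step_size * online + (1 - step_size) * target with the config's tau as step size. A minimal sketch of that rule on a toy parameter tree; the values and step size here are illustrative only:

import jax.numpy as jnp
import optax

online = {"w": jnp.array([1.0, 2.0])}
target = {"w": jnp.array([0.0, 0.0])}
target = optax.incremental_update(online, target, step_size=0.005)
# target["w"] is now 0.005 * [1.0, 2.0] = [0.005, 0.01]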
+ + actor_module = DetActor(action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, layernorm=config.actor_ln, + n_hiddens=config.actor_n_hiddens) + actor = ActorTrainState.create( + apply_fn=actor_module.apply, + params=actor_module.init(actor_key, init_state), + target_params=actor_module.init(actor_key, init_state), + tx=optax.adam(learning_rate=config.actor_learning_rate), + ) + + critic_module = EnsembleCritic(hidden_dim=config.hidden_dim, num_critics=2, layernorm=config.critic_ln, + n_hiddens=config.critic_n_hiddens) + critic = CriticTrainState.create( + apply_fn=critic_module.apply, + params=critic_module.init(critic_key, init_state, init_action), + target_params=critic_module.init(critic_key, init_state, init_action), + tx=optax.adam(learning_rate=config.critic_learning_rate), + ) + + # metrics + bc_metrics_to_log = [ + "critic_loss", "q_min", "actor_loss", "batch_entropy", + "bc_mse_policy", "bc_mse_random", "action_mse" + ] + # shared carry for update loops + carry = { + "key": key, + "actor": actor, + "critic": critic, + "buffer": buffer, + "delayed_updates": jax.numpy.equal( + jax.numpy.arange(config.num_offline_updates + config.num_online_updates) % config.policy_freq, 0 + ).astype(int) + } + + # Online + offline tuning + env, dataset = make_env_and_dataset(config.dataset_name, config.train_seed, False, discount=config.gamma) + eval_env, _ = make_env_and_dataset(config.dataset_name, config.eval_seed, False, discount=config.gamma) + + max_steps = env._max_episode_steps + + action_dim = env.action_space.shape[0] + replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, + config.replay_buffer_size) + replay_buffer.initialize_with_dataset(dataset, None) + online_buffer = OnlineReplayBuffer(env.observation_space, action_dim, config.replay_buffer_size) + + online_batch_size = 0 + offline_batch_size = config.batch_size + + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + @jax.jit + def actor_action_fn(params, obs): + return actor.apply_fn(params, obs) + + eval_successes = [] + train_successes = [] + print("Offline training") + for i in tqdm.tqdm(range(config.num_online_updates + config.num_offline_updates), smoothing=0.1): + carry["metrics"] = Metrics.create(bc_metrics_to_log) + if i == config.num_offline_updates: + print("Online training") + + online_batch_size = int(config.mixing_ratio * config.batch_size) + offline_batch_size = config.batch_size - online_batch_size + # Reset optimizers similar to SPOT + if config.reset_opts: + actor = actor.replace( + opt_state=optax.adam(learning_rate=config.actor_learning_rate).init(actor.params) + ) + critic = critic.replace( + opt_state=optax.adam(learning_rate=config.critic_learning_rate).init(critic.params) + ) + + update_td3_partial = partial( + update_td3, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + normalize_q=config.normalize_q, + use_calibration=config.use_calibration, + ) + + update_td3_no_targets_partial = partial( + update_td3_no_targets, gamma=config.gamma, + tau=config.tau, + policy_noise=config.policy_noise, + noise_clip=config.noise_clip, + use_calibration=config.use_calibration, + ) + online_log = {} + + if i >= config.num_offline_updates: + episode_step += 1 + action = np.asarray(actor_action_fn(carry["actor"].params, observation)) + action = np.array( + [ + ( + action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + next_observation, reward, done, info = 
env.step(action) + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + next_action = np.asarray(actor_action_fn(carry["actor"].params, next_observation))[0] + next_action = np.array( + [ + ( + next_action + + np.random.normal(0, 1 * config.expl_noise, size=action_dim) + ).clip(-1, 1) + ] + )[0] + + if not done or 'TimeLimit.truncated' in info: + mask = 1.0 + else: + mask = 0.0 + real_done = False + if done and episode_step < max_steps: + real_done = True + + online_buffer.insert(observation, action, reward, mask, + float(real_done), next_observation, next_action, 0) + observation = next_observation + if done: + train_successes.append(goal_achieved) + observation, done = env.reset(), False + episode_step = 0 + goal_achieved = False + + if config.num_offline_updates <= i < config.num_offline_updates + config.num_warmup_steps: + continue + + offline_batch = replay_buffer.sample(offline_batch_size) + online_batch = online_buffer.sample(online_batch_size) + batch = concat_batches(offline_batch, online_batch) + + if 'antmaze' in config.dataset_name and config.normalize_reward: + batch["rewards"] *= 100 + + ### Update step + actor_bc_coef = config.actor_bc_coef + critic_bc_coef = config.critic_bc_coef + if i >= config.num_offline_updates: + decay_coef = max(config.min_decay_coef, ( + config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates) + actor_bc_coef *= decay_coef + critic_bc_coef *= 0 + if i % config.policy_freq == 0: + update_fn = partial(update_td3_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + else: + update_fn = partial(update_td3_no_targets_partial, + actor_bc_coef=actor_bc_coef, + critic_bc_coef=critic_bc_coef, + key=key, + actor=carry["actor"], + critic=carry["critic"], + batch=batch, + metrics=carry["metrics"]) + key, new_actor, new_critic, new_metrics = update_fn() + carry.update( + key=key, actor=new_actor, critic=new_critic, metrics=new_metrics + ) + + if i % 1000 == 0: + mean_metrics = carry["metrics"].compute() + common = {f"ReBRAC/{k}": v for k, v in mean_metrics.items()} + common["actor_bc_coef"] = actor_bc_coef + common["critic_bc_coef"] = critic_bc_coef + if i < config.num_offline_updates: + wandb.log({"offline_iter": i, **common}) + else: + wandb.log({"online_iter": i - config.num_offline_updates, **common}) + if i % config.eval_every == 0 or i == config.num_offline_updates + config.num_online_updates - 1 or i == config.num_offline_updates - 1: + eval_returns, success_rate = evaluate(eval_env, carry["actor"].params, actor_action_fn, + config.eval_episodes, + seed=config.eval_seed) + normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 + eval_successes.append(success_rate) + if is_env_with_goal: + online_log["train/regret"] = np.mean(1 - np.array(train_successes)) + offline_log = { + "eval/return_mean": np.mean(eval_returns), + "eval/return_std": np.std(eval_returns), + "eval/normalized_score_mean": np.mean(normalized_score), + "eval/normalized_score_std": np.std(normalized_score), + "eval/success_rate": success_rate + } + offline_log.update(online_log) + wandb.log(offline_log) if __name__ == "__main__": - main() + train() From ea2fa11ef3a4b8a0e8596a18a027a781881e8b23 Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Thu, 17 Aug 2023 21:23:27 +0200 Subject: [PATCH 03/15] linter --- algorithms/finetune/rebrac.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index 8dce4239..8a0ccd52 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -925,7 +925,8 @@ def actor_action_fn(params, obs): critic_bc_coef = config.critic_bc_coef if i >= config.num_offline_updates: decay_coef = max(config.min_decay_coef, ( - config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates) + config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates + ) actor_bc_coef *= decay_coef critic_bc_coef *= 0 if i % config.policy_freq == 0: From 61ebf4a9994cbbc96f13ca1e5ddf9d2934bc5344 Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Thu, 17 Aug 2023 21:44:40 +0200 Subject: [PATCH 04/15] fix linter --- algorithms/finetune/rebrac.py | 131 ++++++++++++++++++++++------------ 1 file changed, 87 insertions(+), 44 deletions(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index 8a0ccd52..68180f18 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -3,30 +3,27 @@ os.environ["TF_CUDNN_DETERMINISTIC"] = "1" import math -import wandb +import random import uuid from copy import deepcopy -import pyrallis -import random +from dataclasses import asdict, dataclass +from functools import partial +from typing import Any, Callable, Dict, Sequence, Tuple, Union import chex import d4rl # noqa +import flax.linen as nn import gym import jax +import jax.numpy as jnp import numpy as np import optax +import pyrallis import tqdm -from typing import Sequence, Union - -from functools import partial -from dataclasses import dataclass, asdict +import wandb from flax.core import FrozenDict -from typing import Dict, Tuple, Any, Callable -import flax.linen as nn -import jax.numpy as jnp -from tqdm.auto import trange - from flax.training.train_state import TrainState +from tqdm.auto import trange ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") @@ -404,7 +401,8 @@ def __init__(self, observation_space: gym.spaces.Box, action_dim: int, def initialize_with_dataset(self, dataset: Dataset, num_samples=None): - assert self.insert_index == 0, 'Can insert a batch online in an empty replay buffer.' + assert self.insert_index == 0, \ + 'Can insert a batch online in an empty replay buffer.' dataset_size = len(dataset.observations) @@ -412,7 +410,8 @@ def initialize_with_dataset(self, dataset: Dataset, num_samples = dataset_size else: num_samples = min(dataset_size, num_samples) - assert self.capacity >= num_samples, 'Dataset cannot be larger than the replay buffer capacity.' + assert self.capacity >= num_samples, \ + 'Dataset cannot be larger than the replay buffer capacity.' 
if num_samples < dataset_size: perm = np.random.permutation(dataset_size) @@ -459,12 +458,16 @@ def __init__(self, discount: float, clip_to_eps: bool = False, eps: float = 1e-5): - d4rl_data = qlearning_dataset(env, env_name, normalize_reward=normalize_reward, discount=discount) + d4rl_data = qlearning_dataset( + env, env_name, normalize_reward=normalize_reward, discount=discount + ) dataset = { "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), "rewards": jnp.asarray(d4rl_data["rewards"], dtype=jnp.float32), - "next_states": jnp.asarray(d4rl_data["next_observations"], dtype=jnp.float32), + "next_states": jnp.asarray( + d4rl_data["next_observations"], dtype=jnp.float32 + ), "next_actions": jnp.asarray(d4rl_data["next_actions"], dtype=jnp.float32), "dones": jnp.asarray(d4rl_data["terminals"], dtype=jnp.float32), "mc_returns": jnp.asarray(d4rl_data["mc_returns"], dtype=jnp.float32) @@ -625,7 +628,9 @@ def actor_loss_fn(params): loss = (beta * bc_penalty - lmbda * q_values).mean() # logging stuff - random_actions = jax.random.uniform(random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0) + random_actions = jax.random.uniform( + random_action_key, shape=batch["actions"].shape, minval=-1.0, maxval=1.0 + ) new_metrics = metrics.update({ "actor_loss": loss, "bc_mse_policy": bc_penalty.mean(), @@ -671,11 +676,16 @@ def update_critic( next_actions = jax.numpy.clip(next_actions + noise, -1, 1) bc_penalty = ((next_actions - batch["next_actions"]) ** 2).sum(-1) - next_q = critic.apply_fn(critic.target_params, batch["next_states"], next_actions).min(0) + next_q = critic.apply_fn( + critic.target_params, batch["next_states"], next_actions + ).min(0) next_q = next_q - beta * bc_penalty target_q = jax.lax.cond( use_calibration, - lambda: jax.numpy.maximum(batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, batch['mc_returns']), + lambda: jax.numpy.maximum( + batch["rewards"] + (1 - batch["dones"]) * gamma * next_q, + batch['mc_returns'] + ), lambda: batch["rewards"] + (1 - batch["dones"]) * gamma * next_q ) @@ -686,7 +696,9 @@ def critic_loss_fn(critic_params): loss = ((q - target_q[None, ...]) ** 2).mean(1).sum(0) return loss, q_min - (loss, q_min), grads = jax.value_and_grad(critic_loss_fn, has_aux=True)(critic.params) + (loss, q_min), grads = jax.value_and_grad( + critic_loss_fn, has_aux=True + )(critic.params) new_critic = critic.apply_gradients(grads=grads) new_metrics = metrics.update({ "critic_loss": loss, @@ -712,11 +724,13 @@ def update_td3( use_calibration: bool, ): key, new_critic, new_metrics = update_critic( - key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics + ) + key, new_actor, new_critic, new_metrics = update_actor( + key, actor, new_critic, batch, actor_bc_coef, tau, + normalize_q, new_metrics ) - key, new_actor, new_critic, new_metrics = update_actor(key, actor, - new_critic, batch, actor_bc_coef, tau, normalize_q, - new_metrics) return key, new_actor, new_critic, new_metrics @@ -736,7 +750,8 @@ def update_td3_no_targets( use_calibration: bool, ): key, new_critic, new_metrics = update_critic( - key, actor, critic, batch, gamma, critic_bc_coef, tau, policy_noise, noise_clip, use_calibration, metrics + key, actor, critic, batch, gamma, critic_bc_coef, tau, + policy_noise, noise_clip, use_calibration, metrics ) return 
key, actor, new_critic, new_metrics @@ -766,7 +781,9 @@ def train(config: Config): id=str(uuid.uuid4()), ) buffer = ReplayBuffer() - buffer.create_from_d4rl(config.dataset_name, config.normalize_reward, config.normalize_states) + buffer.create_from_d4rl( + config.dataset_name, config.normalize_reward, config.normalize_states + ) key = jax.random.PRNGKey(seed=config.train_seed) key, actor_key, critic_key = jax.random.split(key, 3) @@ -774,8 +791,10 @@ def train(config: Config): init_state = buffer.data["states"][0][None, ...] init_action = buffer.data["actions"][0][None, ...] - actor_module = DetActor(action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, layernorm=config.actor_ln, - n_hiddens=config.actor_n_hiddens) + actor_module = DetActor( + action_dim=init_action.shape[-1], hidden_dim=config.hidden_dim, + layernorm=config.actor_ln, n_hiddens=config.actor_n_hiddens + ) actor = ActorTrainState.create( apply_fn=actor_module.apply, params=actor_module.init(actor_key, init_state), @@ -783,8 +802,10 @@ def train(config: Config): tx=optax.adam(learning_rate=config.actor_learning_rate), ) - critic_module = EnsembleCritic(hidden_dim=config.hidden_dim, num_critics=2, layernorm=config.critic_ln, - n_hiddens=config.critic_n_hiddens) + critic_module = EnsembleCritic( + hidden_dim=config.hidden_dim, num_critics=2, + layernorm=config.critic_ln, n_hiddens=config.critic_n_hiddens + ) critic = CriticTrainState.create( apply_fn=critic_module.apply, params=critic_module.init(critic_key, init_state, init_action), @@ -804,13 +825,19 @@ def train(config: Config): "critic": critic, "buffer": buffer, "delayed_updates": jax.numpy.equal( - jax.numpy.arange(config.num_offline_updates + config.num_online_updates) % config.policy_freq, 0 + jax.numpy.arange( + config.num_offline_updates + config.num_online_updates + ) % config.policy_freq, 0 ).astype(int) } # Online + offline tuning - env, dataset = make_env_and_dataset(config.dataset_name, config.train_seed, False, discount=config.gamma) - eval_env, _ = make_env_and_dataset(config.dataset_name, config.eval_seed, False, discount=config.gamma) + env, dataset = make_env_and_dataset( + config.dataset_name, config.train_seed, False, discount=config.gamma + ) + eval_env, _ = make_env_and_dataset( + config.dataset_name, config.eval_seed, False, discount=config.gamma + ) max_steps = env._max_episode_steps @@ -818,7 +845,9 @@ def train(config: Config): replay_buffer = OnlineReplayBuffer(env.observation_space, action_dim, config.replay_buffer_size) replay_buffer.initialize_with_dataset(dataset, None) - online_buffer = OnlineReplayBuffer(env.observation_space, action_dim, config.replay_buffer_size) + online_buffer = OnlineReplayBuffer( + env.observation_space, action_dim, config.replay_buffer_size + ) online_batch_size = 0 offline_batch_size = config.batch_size @@ -834,7 +863,10 @@ def actor_action_fn(params, obs): eval_successes = [] train_successes = [] print("Offline training") - for i in tqdm.tqdm(range(config.num_online_updates + config.num_offline_updates), smoothing=0.1): + for i in tqdm.tqdm( + range(config.num_online_updates + config.num_offline_updates), + smoothing=0.1 + ): carry["metrics"] = Metrics.create(bc_metrics_to_log) if i == config.num_offline_updates: print("Online training") @@ -883,7 +915,9 @@ def actor_action_fn(params, obs): next_observation, reward, done, info = env.step(action) if not goal_achieved: goal_achieved = is_goal_reached(reward, info) - next_action = np.asarray(actor_action_fn(carry["actor"].params, next_observation))[0] + 
next_action = np.asarray( + actor_action_fn(carry["actor"].params, next_observation) + )[0] next_action = np.array( [ ( @@ -910,7 +944,9 @@ def actor_action_fn(params, obs): episode_step = 0 goal_achieved = False - if config.num_offline_updates <= i < config.num_offline_updates + config.num_warmup_steps: + if config.num_offline_updates <= \ + i < \ + config.num_offline_updates + config.num_warmup_steps: continue offline_batch = replay_buffer.sample(offline_batch_size) @@ -924,9 +960,12 @@ def actor_action_fn(params, obs): actor_bc_coef = config.actor_bc_coef critic_bc_coef = config.critic_bc_coef if i >= config.num_offline_updates: - decay_coef = max(config.min_decay_coef, ( - config.num_online_updates + config.num_offline_updates - i + config.num_warmup_steps) / config.num_online_updates - ) + lin_coef = ( + config.num_online_updates + + config.num_offline_updates - + i + config.num_warmup_steps + ) / config.num_online_updates + decay_coef = max(config.min_decay_coef, lin_coef) actor_bc_coef *= decay_coef critic_bc_coef *= 0 if i % config.policy_freq == 0: @@ -961,10 +1000,14 @@ def actor_action_fn(params, obs): wandb.log({"offline_iter": i, **common}) else: wandb.log({"online_iter": i - config.num_offline_updates, **common}) - if i % config.eval_every == 0 or i == config.num_offline_updates + config.num_online_updates - 1 or i == config.num_offline_updates - 1: - eval_returns, success_rate = evaluate(eval_env, carry["actor"].params, actor_action_fn, - config.eval_episodes, - seed=config.eval_seed) + if i % config.eval_every == 0 or\ + i == config.num_offline_updates + config.num_online_updates - 1 or\ + i == config.num_offline_updates - 1: + eval_returns, success_rate = evaluate( + eval_env, carry["actor"].params, actor_action_fn, + config.eval_episodes, + seed=config.eval_seed + ) normalized_score = eval_env.get_normalized_score(eval_returns) * 100.0 eval_successes.append(success_rate) if is_env_with_goal: From 94aa863b6bee50fd67b171e6d19a81488eefff8e Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Sat, 19 Aug 2023 14:48:57 +0200 Subject: [PATCH 05/15] Fix --- algorithms/finetune/rebrac.py | 157 +++++++++++++++++++++++++--------- 1 file changed, 115 insertions(+), 42 deletions(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index 68180f18..f5311b21 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -20,11 +20,12 @@ import optax import pyrallis import tqdm -import wandb from flax.core import FrozenDict from flax.training.train_state import TrainState from tqdm.auto import trange +import wandb + ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") default_kernel_init = nn.initializers.lecun_normal() @@ -207,47 +208,113 @@ def __call__(self, state: jax.Array, action: jax.Array) -> jax.Array: return q_values +def calc_return_to_go(is_sparse_reward, rewards, terminals, gamma): + """ + A config dict for getting the default high/low rewrd values for each envs + This is used in calc_return_to_go func in sampler.py and replay_buffer.py + """ + if len(rewards) == 0: + return [] + reward_neg = 0 + if is_sparse_reward and np.all(np.array(rewards) == reward_neg): + """ + If the env has sparse reward and the trajectory is all negative rewards, + we use r / (1-gamma) as return to go. 
+ For example, if gamma = 0.99 and the rewards = [-1, -1, -1], + then return_to_go = [-100, -100, -100] + """ + # assuming failure reward is negative + # use r / (1-gamma) for negative trajectory + return_to_go = [float(reward_neg / (1 - gamma))] * len(rewards) + else: + return_to_go = [0] * len(rewards) + prev_return = 0 + for i in range(len(rewards)): + return_to_go[-i - 1] = \ + rewards[-i - 1] + gamma * prev_return * (1 - terminals[-i - 1]) + prev_return = return_to_go[-i - 1] + + return return_to_go + + def qlearning_dataset( - env: gym.Env, - dataset: Dict = None, - terminate_on_end: bool = False, - **kwargs, -) -> Dict: + env, dataset_name, + normalize_reward=False, dataset=None, + terminate_on_end=False, discount=0.99, **kwargs +): + """ + Returns datasets formatted for use by standard Q-learning algorithms, + with observations, actions, next_observations, next_actions, rewards, + and a terminal flag. + Args: + env: An OfflineEnv object. + dataset: An optional dataset to pass in for processing. If None, + the dataset will default to env.get_dataset() + terminate_on_end (bool): Set done=True on the last timestep + in a trajectory. Default is False, and will discard the + last timestep in each trajectory. + **kwargs: Arguments to pass to env.get_dataset(). + Returns: + A dictionary containing keys: + observations: An N x dim_obs array of observations. + actions: An N x dim_action array of actions. + next_observations: An N x dim_obs array of next observations. + next_actions: An N x dim_action array of next actions. + rewards: An N-dim float array of rewards. + terminals: An N-dim boolean array of "done" or episode termination flags. + """ if dataset is None: dataset = env.get_dataset(**kwargs) - - N = dataset["rewards"].shape[0] + if normalize_reward: + dataset['rewards'] = ReplayBuffer.normalize_reward( + dataset_name, dataset['rewards'] + ) + N = dataset['rewards'].shape[0] + is_sparse = "antmaze" in dataset_name obs_ = [] next_obs_ = [] action_ = [] next_action_ = [] reward_ = [] done_ = [] - + mc_returns_ = [] + print("SIZE", N) # The newer version of the dataset adds an explicit # timeouts field. Keep old method for backwards compatibility.
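As a quick check on the return-to-go logic introduced above, here is a minimal, self-contained sketch with toy numbers. The helper name `return_to_go` and the explicit `reward_neg` argument are illustrative additions for this example; the patch itself fixes `reward_neg = 0` inside `calc_return_to_go`.

```python
import numpy as np


def return_to_go(rewards, terminals, gamma, is_sparse_reward, reward_neg=0.0):
    """Discounted return-to-go per step, mirroring the logic sketched above."""
    if len(rewards) == 0:
        return []
    if is_sparse_reward and np.all(np.array(rewards) == reward_neg):
        # All-failure sparse trajectory: treat it as receiving reward_neg forever,
        # i.e. the geometric series reward_neg / (1 - gamma) at every step.
        return [float(reward_neg / (1 - gamma))] * len(rewards)
    rtg = [0.0] * len(rewards)
    prev = 0.0
    for i in range(len(rewards)):
        # Walk the trajectory backwards; a terminal flag cuts off the discounted tail.
        rtg[-i - 1] = rewards[-i - 1] + gamma * prev * (1 - terminals[-i - 1])
        prev = rtg[-i - 1]
    return rtg


# Dense case with gamma = 0.5 so the arithmetic is easy to verify by hand:
# rtg[2] = 1, rtg[1] = 1 + 0.5 * 1 = 1.5, rtg[0] = 1 + 0.5 * 1.5 = 1.75
print(return_to_go([1.0, 1.0, 1.0], [0, 0, 1], gamma=0.5, is_sparse_reward=False))

# Sparse case from the docstring: with reward_neg = -1 and gamma = 0.99,
# every step gets -1 / (1 - 0.99) = -100 (up to floating-point rounding).
print(return_to_go([-1.0, -1.0, -1.0], [0, 0, 1], gamma=0.99,
                   is_sparse_reward=True, reward_neg=-1.0))
```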
- use_timeouts = "timeouts" in dataset + use_timeouts = 'timeouts' in dataset episode_step = 0 + episode_rewards = [] + episode_terminals = [] for i in range(N - 1): - obs = dataset["observations"][i].astype(np.float32) - new_obs = dataset["observations"][i + 1].astype(np.float32) - action = dataset["actions"][i].astype(np.float32) - new_action = dataset["actions"][i + 1].astype(np.float32) - reward = dataset["rewards"][i].astype(np.float32) - done_bool = bool(dataset["terminals"][i]) + if episode_step == 0: + episode_rewards = [] + episode_terminals = [] + + obs = dataset['observations'][i].astype(np.float32) + new_obs = dataset['observations'][i + 1].astype(np.float32) + action = dataset['actions'][i].astype(np.float32) + new_action = dataset['actions'][i + 1].astype(np.float32) + reward = dataset['rewards'][i].astype(np.float32) + done_bool = bool(dataset['terminals'][i]) if use_timeouts: - final_timestep = dataset["timeouts"][i] + final_timestep = dataset['timeouts'][i] else: - final_timestep = episode_step == env._max_episode_steps - 1 + final_timestep = (episode_step == env._max_episode_steps - 1) if (not terminate_on_end) and final_timestep: - # Skip this transition + # Skip this transition and don't apply terminals on the last step of episode episode_step = 0 + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) continue if done_bool or final_timestep: episode_step = 0 + episode_rewards.append(reward) + episode_terminals.append(done_bool) + obs_.append(obs) next_obs_.append(new_obs) action_.append(action) @@ -255,14 +322,19 @@ def qlearning_dataset( reward_.append(reward) done_.append(done_bool) episode_step += 1 - + if episode_step != 0: + mc_returns_ += calc_return_to_go( + is_sparse, episode_rewards, episode_terminals, discount + ) + assert np.array(mc_returns_).shape == np.array(reward_).shape return { - "observations": np.array(obs_), - "actions": np.array(action_), - "next_observations": np.array(next_obs_), - "next_actions": np.array(next_action_), - "rewards": np.array(reward_), - "terminals": np.array(done_), + 'observations': np.array(obs_), + 'actions': np.array(action_), + 'next_observations': np.array(next_obs_), + 'next_actions': np.array(next_action_), + 'rewards': np.array(reward_), + 'terminals': np.array(done_), + 'mc_returns': np.array(mc_returns_), } @@ -288,7 +360,7 @@ def create_from_d4rl( normalize_reward: bool = False, is_normalize: bool = False, ): - d4rl_data = qlearning_dataset(gym.make(dataset_name)) + d4rl_data = qlearning_dataset(gym.make(dataset_name), dataset_name) buffer = { "states": jnp.asarray(d4rl_data["observations"], dtype=jnp.float32), "actions": jnp.asarray(d4rl_data["actions"], dtype=jnp.float32), @@ -451,13 +523,13 @@ def insert(self, observation: np.ndarray, action: np.ndarray, class D4RLDataset(Dataset): - def __init__(self, - env: gym.Env, - env_name: str, - normalize_reward: bool, - discount: float, - clip_to_eps: bool = False, - eps: float = 1e-5): + def __init__( + self, + env: gym.Env, + env_name: str, + normalize_reward: bool, + discount: float, + ): d4rl_data = qlearning_dataset( env, env_name, normalize_reward=normalize_reward, discount=discount ) @@ -572,27 +644,28 @@ def is_goal_reached(reward: float, info: Dict) -> bool: def evaluate( - env: gym.Env, - params: jax.Array, - action_fn: Callable, - num_episodes: int, - seed: int, -) -> np.ndarray: + env: gym.Env, params, action_fn: Callable, num_episodes: int, seed: int +) -> Tuple[np.ndarray, np.ndarray]: env.seed(seed) 
env.action_space.seed(seed) env.observation_space.seed(seed) returns = [] + successes = [] for _ in trange(num_episodes, desc="Eval", leave=False): obs, done = env.reset(), False + goal_achieved = False total_reward = 0.0 while not done: action = np.asarray(jax.device_get(action_fn(params, obs))) - obs, reward, done, _ = env.step(action) + obs, reward, done, info = env.step(action) total_reward += reward + if not goal_achieved: + goal_achieved = is_goal_reached(reward, info) + successes.append(float(goal_achieved)) returns.append(total_reward) - return np.array(returns) + return np.array(returns), np.mean(successes) class CriticTrainState(TrainState): From 73e1747101da7e63c88e08341a9069811f9cff33 Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Sat, 19 Aug 2023 21:12:44 +0200 Subject: [PATCH 06/15] Fix configs --- configs/finetune/rebrac/antmaze/large_diverse_v2.yaml | 2 +- configs/finetune/rebrac/antmaze/large_play_v2.yaml | 2 +- configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml | 2 +- configs/finetune/rebrac/antmaze/medium_play_v2.yaml | 2 +- configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml | 2 +- configs/finetune/rebrac/antmaze/umaze_v2.yaml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml index 4da606f8..cd208878 100644 --- a/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml +++ b/configs/finetune/rebrac/antmaze/large_diverse_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 diff --git a/configs/finetune/rebrac/antmaze/large_play_v2.yaml b/configs/finetune/rebrac/antmaze/large_play_v2.yaml index 0210a3ab..77e56655 100644 --- a/configs/finetune/rebrac/antmaze/large_play_v2.yaml +++ b/configs/finetune/rebrac/antmaze/large_play_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 diff --git a/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml index 341292eb..d5d56435 100644 --- a/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml +++ b/configs/finetune/rebrac/antmaze/medium_diverse_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 diff --git a/configs/finetune/rebrac/antmaze/medium_play_v2.yaml b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml index 4cf200be..e1c4966a 100644 --- a/configs/finetune/rebrac/antmaze/medium_play_v2.yaml +++ b/configs/finetune/rebrac/antmaze/medium_play_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 diff --git a/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml index f2639b9c..41fdf292 100644 --- a/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml +++ b/configs/finetune/rebrac/antmaze/umaze_diverse_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 
normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 diff --git a/configs/finetune/rebrac/antmaze/umaze_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_v2.yaml index e173e309..32cd2149 100644 --- a/configs/finetune/rebrac/antmaze/umaze_v2.yaml +++ b/configs/finetune/rebrac/antmaze/umaze_v2.yaml @@ -20,7 +20,7 @@ mixing_ratio: 0.5 name: rebrac-finetune noise_clip: 0.5 normalize_q: true -normalize_reward: false +normalize_reward: true normalize_states: false num_offline_updates: 1000000 num_online_updates: 1000000 From 125d77c7180b9a8c22a9a6eb7edb1d6cc854a7ba Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Sun, 20 Aug 2023 17:20:17 +0200 Subject: [PATCH 07/15] Fix umaze config --- configs/finetune/rebrac/antmaze/umaze_v2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/finetune/rebrac/antmaze/umaze_v2.yaml b/configs/finetune/rebrac/antmaze/umaze_v2.yaml index 32cd2149..042b5232 100644 --- a/configs/finetune/rebrac/antmaze/umaze_v2.yaml +++ b/configs/finetune/rebrac/antmaze/umaze_v2.yaml @@ -3,7 +3,7 @@ actor_learning_rate: 0.0003 actor_ln: false actor_n_hiddens: 3 batch_size: 256 -critic_bc_coef: 1.0 +critic_bc_coef: 0.002 critic_learning_rate: 0.00005 critic_ln: true critic_n_hiddens: 3 From 72bd251be3bde002a2da940c17a5ca5641fec61a Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Fri, 15 Sep 2023 12:19:54 +0200 Subject: [PATCH 08/15] Add scoring --- README.md | 93 +++++++++++------------ results/bin/finetune_scores.pickle | Bin 705857 -> 741701 bytes results/get_finetune_scores.py | 8 +- results/get_finetune_tables_and_plots.py | 6 +- results/get_finetune_urls.py | 5 ++ results/runs_tables/finetune_urls.csv | 54 +++++++++++-- 6 files changed, 109 insertions(+), 57 deletions(-) diff --git a/README.md b/README.md index 41ffde38..917b6a60 100644 --- a/README.md +++ b/README.md @@ -32,25 +32,24 @@ docker run --gpus=all -it --rm --name ## Algorithms Implemented -| Algorithm | Variants Implemented | Wandb Report | -|--------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------| ----------- | -| **Offline and Offline-to-Online** | | -| ✅ [Conservative Q-Learning for Offline Reinforcement Learning
(CQL)](https://arxiv.org/abs/2006.04779) | [`offline/cql.py`](algorithms/offline/cql.py)
[`finetune/cql.py`](algorithms/finetune/cql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-CQL--VmlldzoyNzA2MTk5)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-CQL--Vmlldzo0NTQ3NTMz) -| ✅ [Accelerating Online Reinforcement Learning with Offline Datasets
(AWAC)](https://arxiv.org/abs/2006.09359) | [`offline/awac.py`](algorithms/offline/awac.py)
[`finetune/awac.py`](algorithms/finetune/awac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-AWAC--VmlldzoyNzA2MjE3)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-AWAC--VmlldzozODAyNzQz) -| ✅ [Offline Reinforcement Learning with Implicit Q-Learning
(IQL)](https://arxiv.org/abs/2110.06169) | [`offline/iql.py`](algorithms/offline/iql.py)
[`finetune/iql.py`](algorithms/finetune/iql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-IQL--VmlldzoyNzA2MTkx)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-IQL--VmlldzozNzE1MTEy) -| **Offline-to-Online only** | | -| ✅ [Supported Policy Optimization for Offline Reinforcement Learning
(SPOT)](https://arxiv.org/abs/2202.06239) | [`finetune/spot.py`](algorithms/finetune/spot.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-SPOT--VmlldzozODk5MTgx) -| ✅ [Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning
(Cal-QL)](https://arxiv.org/abs/2303.05479) | [`finetune/cal_ql.py`](algorithms/finetune/cal_ql.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-Cal-QL--Vmlldzo0NTQ3NDk5) -| **Offline only** | | -| ✅ Behavioral Cloning
(BC) | [`offline/any_percent_bc.py`](algorithms/offline/any_percent_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-BC--VmlldzoyNzA2MjE1) -| ✅ Behavioral Cloning-10%
(BC-10%) | [`offline/any_percent_bc.py`](algorithms/offline/any_percent_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-BC-10---VmlldzoyNzEwMjcx) -| ✅ [A Minimalist Approach to Offline Reinforcement Learning
(TD3+BC)](https://arxiv.org/abs/2106.06860) | [`offline/td3_bc.py`](algorithms/offline/td3_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-TD3-BC--VmlldzoyNzA2MjA0) -| ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling
(DT)](https://arxiv.org/abs/2106.01345) | [`offline/dt.py`](algorithms/offline/dt.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-Decision-Transformer--VmlldzoyNzA2MTk3) -| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(SAC-N)](https://arxiv.org/abs/2110.01548) | [`offline/sac_n.py`](algorithms/offline/sac_n.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-SAC-N--VmlldzoyNzA1NTY1) -| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(EDAC)](https://arxiv.org/abs/2110.01548) | [`offline/edac.py`](algorithms/offline/edac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-EDAC--VmlldzoyNzA5ODUw) -| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2) -| ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size
(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`offline/lb_sac.py`](algorithms/offline/lb_sac.py) | [`Offline Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1) - +| Algorithm | Variants Implemented | Wandb Report | +|--------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Offline and Offline-to-Online** | | | +| ✅ [Conservative Q-Learning for Offline Reinforcement Learning
(CQL)](https://arxiv.org/abs/2006.04779) | [`offline/cql.py`](algorithms/offline/cql.py)
[`finetune/cql.py`](algorithms/finetune/cql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-CQL--VmlldzoyNzA2MTk5)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-CQL--Vmlldzo0NTQ3NTMz) | +| ✅ [Accelerating Online Reinforcement Learning with Offline Datasets
(AWAC)](https://arxiv.org/abs/2006.09359) | [`offline/awac.py`](algorithms/offline/awac.py)
[`finetune/awac.py`](algorithms/finetune/awac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-AWAC--VmlldzoyNzA2MjE3)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-AWAC--VmlldzozODAyNzQz) | +| ✅ [Offline Reinforcement Learning with Implicit Q-Learning
(IQL)](https://arxiv.org/abs/2110.06169) | [`offline/iql.py`](algorithms/offline/iql.py)
[`finetune/iql.py`](algorithms/finetune/iql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-IQL--VmlldzoyNzA2MTkx)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-IQL--VmlldzozNzE1MTEy) | +| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py)
[`finetune/rebrac.py`](algorithms/finetune/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-ReBRAC--Vmlldzo1NDAyNjE5) | +| **Offline-to-Online only** | | | +| ✅ [Supported Policy Optimization for Offline Reinforcement Learning
(SPOT)](https://arxiv.org/abs/2202.06239) | [`finetune/spot.py`](algorithms/finetune/spot.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-SPOT--VmlldzozODk5MTgx) | +| ✅ [Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning
(Cal-QL)](https://arxiv.org/abs/2303.05479) | [`finetune/cal_ql.py`](algorithms/finetune/cal_ql.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-Cal-QL--Vmlldzo0NTQ3NDk5) | +| **Offline only** | | | +| ✅ Behavioral Cloning
(BC) | [`offline/any_percent_bc.py`](algorithms/offline/any_percent_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-BC--VmlldzoyNzA2MjE1) | +| ✅ Behavioral Cloning-10%
(BC-10%) | [`offline/any_percent_bc.py`](algorithms/offline/any_percent_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-BC-10---VmlldzoyNzEwMjcx) | +| ✅ [A Minimalist Approach to Offline Reinforcement Learning
(TD3+BC)](https://arxiv.org/abs/2106.06860) | [`offline/td3_bc.py`](algorithms/offline/td3_bc.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-TD3-BC--VmlldzoyNzA2MjA0) | +| ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling
(DT)](https://arxiv.org/abs/2106.01345) | [`offline/dt.py`](algorithms/offline/dt.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-Decision-Transformer--VmlldzoyNzA2MTk3) | +| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(SAC-N)](https://arxiv.org/abs/2110.01548) | [`offline/sac_n.py`](algorithms/offline/sac_n.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-SAC-N--VmlldzoyNzA1NTY1) | +| ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(EDAC)](https://arxiv.org/abs/2110.01548) | [`offline/edac.py`](algorithms/offline/edac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-EDAC--VmlldzoyNzA5ODUw) | +| ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size
(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`offline/lb_sac.py`](algorithms/offline/lb_sac.py) | [`Offline Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1) | ## D4RL Benchmarks You can check the links above for learning curves and details. Here, we report reproduced **final** and **best** scores. Note that they differ by a significant margin, and some papers may use different approaches, not making it always explicit which reporting methodology they chose. If you want to re-collect our results in a more structured/nuanced manner, see [`results`](results). @@ -170,42 +169,42 @@ You can check the links above for learning curves and details. Here, we report r ### Offline-to-Online #### Scores -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43| -|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12| -|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64| -|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48| -|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79| -|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 -> 74.75 ± 42.59| +|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 -> 98.00 ± 2.92| +|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 -> 98.00 ± 1.58| +|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 -> 98.75 ± 0.43| +|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 -> 31.50 ± 33.56| +|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 -> 72.25 ± 41.73| | | | | | | | | | | | -| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33| +| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 -> 78.88| | | | | | | | | | | | -|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → 
-2.68 ± 0.12| -|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01| -|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17| -|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04| +|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 -> 138.15 ± 3.22| +|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 -> 102.39 ± 8.27| +|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 -> 124.65 ± 7.37| +|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 -> 6.96 ± 4.59| | | | | | | | | | | | -| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79| +| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 -> 93.04| #### Regrets -| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL| -|---------------------------|------------|--------|--------|-----|-----| -|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00| -|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01| -|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01| -|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01| -|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02| -|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02| +| **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| +|---------------------------|------------|--------|--------|-----|-----|-----| +|antmaze-umaze-v2|0.04 ± 0.01|0.02 ± 0.00|0.07 ± 0.00|0.02 ± 0.00|0.01 ± 0.00|0.11 ± 0.18| +|antmaze-umaze-diverse-v2|0.88 ± 0.01|0.09 ± 0.01|0.43 ± 0.11|0.22 ± 0.07|0.05 ± 0.01|0.04 ± 0.02| +|antmaze-medium-play-v2|1.00 ± 0.00|0.08 ± 0.01|0.09 ± 0.01|0.06 ± 0.00|0.04 ± 0.01|0.03 ± 0.01| +|antmaze-medium-diverse-v2|1.00 ± 0.00|0.08 ± 0.00|0.10 ± 0.01|0.05 ± 0.01|0.04 ± 0.01|0.03 ± 0.00| +|antmaze-large-play-v2|1.00 ± 0.00|0.21 ± 0.02|0.34 ± 0.05|0.29 ± 0.07|0.13 ± 0.02|0.14 ± 0.05| +|antmaze-large-diverse-v2|1.00 ± 0.00|0.21 ± 0.03|0.41 ± 0.03|0.23 ± 0.08|0.13 ± 0.02|0.29 ± 0.39| | | | | | | | | | | | -| **antmaze average** |0.82|0.11|0.24|0.15|0.07| +| **antmaze average** |0.82|0.11|0.24|0.15|0.07|0.11| | | | | | | | | | | | -|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01| -|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00| -|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00| -|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00| +|pen-cloned-v1|0.46 ± 0.02|0.97 ± 0.00|0.37 ± 0.01|0.58 ± 0.02|0.98 ± 0.01|0.08 ± 0.01| +|door-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.83 ± 0.03|0.99 ± 0.01|1.00 ± 0.00|0.19 ± 0.05| +|hammer-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|0.65 ± 0.10|0.98 ± 0.01|1.00 ± 0.00|0.13 ± 0.03| 
+|relocate-cloned-v1|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|1.00 ± 0.00|0.90 ± 0.06| | | | | | | | | | | | -| **adroit average** |0.86|0.99|0.71|0.89|0.99| +| **adroit average** |0.86|0.99|0.71|0.89|0.99|0.33| ## Citing CORL diff --git a/results/bin/finetune_scores.pickle b/results/bin/finetune_scores.pickle index d60377ec763001d8b678a2bb4c22d82ed4dd82d5..a0a682d367051865f229727df37b46f1e4df7e9d 100644 GIT binary patch delta 36898 zcmc&-30RF=_pj5DBMnq48Hz}CMdqYvG^9x-Lq#EDrBWpFmJm@wIxZnZ>58aiu4F2g zq{uA8bq!H&G${JN=bZPPv)5@Ey5ILb|MT2D`nC4jYp=b2YpuQCH@6_>x6AttBi2Y~ zvpC(kPAY7bTq7+OONOh|jjh5BaF$i!)~w?wb2opLR^=vtkXGSVx^h&wUfL`LuGvRv z4Q{fHY_Ht>?ySWuZoVo@mD?$b{;^Se!H-q8d$2UPPIF~dx$)y^RkS{3ZvHAc)##eu zEc#0(x28ASRhu`jF1M2{PeO)eC-|GTG)tN5Wz9C>f1B$xj@^^RSS7<6nHz7#u4ZwQ z?b$BeconuU*K3Zfa&G=mwh}9`8(WF{&5brgTZ27Qpe(s7yOhym%-$^>y}^-9f3?fq z;3!+C$!&JwFssRQq)n}vMO&%8nL{t;3BBL|OZr|d?$@XDIUsxhb zf3>H-IJp;=bhx)kMoLOL=Xbd?`?+IQ(`%3DOKVfepp~1~(<6=4=x-}|*5m~XHb~`U zs$4kZQS-$##{c8#W{xg5e3=}R!MB)f3@doG)OCPk2KKoHEeqRWGi|#nUTB&_olzioJ%`4LW3SjhZF5$ z4ojCCe^Bl|H@-VtIXC{WTu&J;KMGoM=pi>{XhCh>g7Q1i(()N{s$4o$=~ZRY7%%K{ z{WImJu=qjDgb+`f{;>l0oII;5e=)hVOAaurrp1|`lql1Bl=jim^aRAFayR?X7Mmx? zspJM^$(3}bHpx2sSUOu+C4?vF&6n{^(4+sG&zY}i#VLC2cDN|({FsDw3C;-uXs@@& z<^!kx$Ri+|y-S_zYg0%m(+1m39X*eb(CwvevPpW6NJ!(G+MQmBl$trAZ0P!nX;)Qc zNvNqjHZ+lSmKy6{oij->PQkbvhW1ZLC?U48Fx4xSg!azr7Hu}_BMIF(yL9R98I>fYajDm# zxqA5mh!Qx~@P)O<@#f>C*n^`zZQBoLtnd--O8=YhuY9#mBPBbE3XXBB66Vk5guIG+asSo~PNZpiNUr$ab^V<~KQ0NN8cLQu?fc(@1D@ z$&b=@#CFM?ztR1uA%adgQ@?SsH+xN?K zQflRj5fdx@fz+jMQxkKx0jP)l)EKD<05zY#_w?m54_X%`if!uEUuj8FK!0vPK)(Y+ zcw!={?B>JmSE@3~e8!Er`0;Sl-@Y586GR)7v&Jp9Mw!`O!R}V|aPGhFQF)6c~B9B{p|T$thB7_bl0l0)1e`qh{5r+Btco)Yz!3t`?yn zUxdZ29RJS+K(6Vfbw^$HiwHsSGR=#iPIr2rK1V96a~fcJ+5$jh2d4#>_5_Z-w=Xa& zw;`7tC_Crivw;JcDaEvjIXdi%?7a)5WZVsxHJ+5EJn5s^$M+VM%>}$U)pM2pn2C0z zl=g$(Vb=o2lh6UX;hM5O<48#Vbwj;VgE-Xbh3->)Xn*A@2~nIS=gU@peh%1GYkAID zstEArP?YAPx$Qx8#+e6hI$#P0R$kRma_a=7h8)~pwnPCyN1QbjvzQpBn9TkqX6u^t zW-w*v_~EnD27{49zpjkyege=wR3Ur&pS1u|eYALHw<}vM%WOcBN%Li%^_g*)lnQ(~SUY!n770E2+_CD05hyoD ziw8&_%K=g)Hz&VZTT&z*NLg{@Q^NcgudWeeR-E7$0bPBL19kPSNGS;jD{#MyUT*Yt?bdP-Q@74nul8}uF_7%AqBu2i34vvZb{ z3O(wg|FTi(ONM%IZBlPDrBNiLzHsp8yxC(($nR0j&i&S)PCcU14;`IQZc2*zI^8|G zLK}>%I#s@TbbE7BO3ril^ip3l5~@3XE=R+fWEaP#d!+Ho2LO`cM(kFc;KB@~_^po1 zC9@5tkV;z4C)0c`fgm25=6IYs1u`?0-@6IKFcZ{}yO?XAq%PnVlgezmbC%FV}(2HP*Hf>3d^d1aaAoSjFg zUg|ljGO_;|66(Klc4gL6>NL5vcKXNM!q#PzP<`x@NP`Dbsd)%l>@ObM^gyyNIZ zik-Wo&HFLQIE`$)GTu0|xZ(fiE<3pLPY zE$}u;I(WMyDYdnGZqV3^!vzo%aZ3u?dvs*Ukz&3^7v6Rb??ggWMP8-wH!655cg(d$Kw1_>2B>#n=b4jkSbJO|osf7p(cI!NuB#tN!qMv578Z0$$?n?)>A z)FGmu(jK2NB*e`dcc4l^n}qz9u27Du0q2_XY1JX=TZJi}pW!PTN6AFpnL!T}cl?Eo zv)Bz*UZk?%jEXWJdjM75_!@C=3V>YZ9h&A~Hj^Cqda7dMSPDobR-IjcW;TGbPu=vR zO2+|AZFlo!7e%1#vQDh2o7RjWS zqmm{accMU>VQ%yQfht3 zkU@RA1CAJ;nzrchDWFSV$!Jel6oB+vvhJS^q#;Utm^5tSYI~3ar7y-%BhvZLy`)|85R)}e{Rj%Pk^K~F;ZxMb!=Q1cN<%7cdmlWFaEfIG7pFWLollv1K zZXE|`%zLlRBZ(+0H-Gr#@t`*dm43fVWpW$<^z_lS@_A|iitS*%CwmfrmK@bkkbgvZ zlKP+LDyTFC0f3#02%ppDD^x5AS!O3 z$E!t+fH!kqs(kI41;kF(CO77u0+6kaUC$CI_2vPnndU-s<=N~2x*xh^Ls zbt9K#vHng|WS9yGx%5dhsdxnl$I(nx*tnmO;?qXZ@#lB8d^przi5z`bE_kNLI4csG z)EI7Junp9fE!LwZt?~y@xmMAs;eUa0^S;#Uo4Xkx*YImb2Y2uD1RXzjUiJIPZ-5Xz zzLs?GNC6Gb3ih23fro%pdGB34Jc~gL4}37Jr>#GLOwMs*_1M6Q`_HH*!;c+cTpO!D z+L;YTF1^)hhetDj=2p-7XT1s-Sis8JVlD;rJ9O{BeNzNd$XL&JeMaet2J-Vi_kGN} zotqf_G;c1tdZo{~*g)_mqM}K6;ADa~R-FE_R=!)jz|B{+UwDFhc?WXOQjX;LKT2sJ zY18age75#?rb}b`FPUwe_j?M!~X3^ z=_v^H;Sa1j+nxXswdJ=pxAjv%bZR*+D&7jidO*OFL;7%_+^nn#)b$>sSmY(3kAa6; zGFSkrf0gt>c{G^P%gr}w=0Y&#!Sz*}mt6%=-sBFeKbZhsDnk_$$GH$v&Q_ehi`Ol> 
z(H)4{*Mx@;d`gH}adIZ_=%Ehl?3usZ*bjXe!qK!jYpXLd|J?^}=h5v{sW-3sY*ad< zN^V5(UZ;e*7;q-+lU5isW0w;t<(V!&>TE5@U+UUv1s`+3CB`82`Ys`X0IGY_C-bx> zGmzm`OV{K^&S0SbCd=;Y$-MyjbFuNFG0MO&-^-EqN1lPi8aYyzlT!$$44{VjU3%?I zn*T9%lW}9xWQxQN2fZBoxdG@;=`g|WVgZ28JoqEzkL3U|^9oRY{tiqT`D~TP$$kKu zIPFksQ7Z6QAmzDSS7#%mOw5=bvuyKRN4Sy3?Fu1F^I#FFN=|B#7u!uEV;$-36o^9%uT_m-QxX zt>#|8w&DhmN_v&lz795HKHZ*&QJ(#QSe5&V%7kVRW?zT&n9CRZv{H=L`0|^aqx7@p?41`K;pzqxbX;CQh>ssiP5sBevL9Cee3*Fk+ z9zZ*{|K9R!BM5biJt3@afdIPbdagXJ!}F~T2=0f`N#0jnc;UI!j_exe2p5wm&Kwtn}6S zia%3EfC^MF$S(ZRHxTMFseg?7wgo^5uh)*Q2n8{=Q{j&Rq3T}XTywYk^hRZAAU3G{ z%;%BiAQUI<;ZQNr02>$bYoz`W4EOsV6Vwjb5;(E1(^PqF71x% z0_eq{=qF3z5&m`dp|VrYfc|ZET{^El0tQ~2S9ImLJ%BbJP+)sOte?1lyl2JERKD{{#2;sivxxMBfASLhMw5~E2 zNKG!)JL#}M6lA<2{bpKzh#xEmr?k!mP|xpxk#A$o-{nPsk#4ECE9T0BAol$-wToL6 zkg~|z|B00XpsdG+nMtsY=-u98Z52=GqR%y1k4u#f)_{>y>$WxZf>+rWZH%W|z6Vmt zF$?8Cq%#mro2&~izp?x5<=&E!Q>D&#-MjN;;4$(*-!h?=s(bK|gz6TCesnCZAR(Fd zcdU~~l#x)VTqM_g#zzu5ncXa3Y~7xW-a#|H#-6&!T*fiHdYjsJj`C}8XP2%!cE0M& zCQ{ku3tdycy!uQ+DsMW@?qv8k39Ylrx?Vg3T=w-iZog>y=1!y)pL!1Up8vR(B1f83 z9*I=_^p1q8lOFZGd=I=nP#9L#WMc)aFe@9m`?C3Sa-d0OX6|v{FC?^VzM@W)m+0jj zUbJOUGnO^<(R(Zq6Eh~RgPQl68(<3#q!cAzIsx8@i(VI^^KKt_V!BOntZE;7@Fou1 zzWP(Z>3Gr-1)WU}d#jG+k@p495-u5S*aVancdAroCCAfZq8n>-*-*`wQHx1s(<_~2 zma^`UP_CiEJht?E5*oEv@x|4WY7(lKyXQv9doYi81&nX7Y*b$T6+GJYy&l%XSn4G? zQb$><>EpP4d4yrvoclGLw^gK6c#7qiwWGk4YxlZD{kbU4``%X>cTX9`uNj;w2eT73JKafz!IO#t+ z4->71U)QQRzwKg=10x6V%ViikBIEA;9fg2 zN+x(=JqX2ZOYFL2?Un-PnyR>Vd&;f?V=o;zX=UzMM_Qq)wzd5W4Zu{J#;5PLT>|=z z^S%tGU)UxG9z5~6#{M}NsFZw{(%-udY??*JIX$_%fJzl>(rdPQ0FLi7SfW@H!H_{r zo6YgjSGw5TB&S`w!RW?{1AmZ^=l%&HC0;>!yyy|0YwqT$Crw!Zrv0#9_s{DEcgc|^ z)(wRp%0Plxd&*WWe=3L`gPLif*KUK%CB+GQbCL@(qhaS8>Mem3aIT6&$a&Yl=xq@b z&eUj{n(74Lwkt!gsfHxu0k5U8>%#X8t0$!>WtVX~+|Pgz zdhtpp_?VA6X+^thvl8aHfSc5ANkhh@ZRt!(eflfdG^6o%UOTeBryiIIbMX7_W-PLt`$;1q>L&H~%)B3y)_(I>Xz(n&1 zk;Uk@e2c?NvlqW2-D7&ya9XlG2-e+I3k%mPGf_^QYXXOjNVL8KDwD;3 zHw#{|rNBFcMepZmnOI#VrNT-GNMz=2=r-SDR^~b0I zN!IJZ6OMCTF1?QU4LrIC((Jfuk0E#+BmGs){2FyIFm0yoSI$Xo(u%?zh7YdI2G0sl z9{+sK#zu6q=jZ=jhVmQNuIeG;O;wb#KR4Ei;Ei&+ZEiO-5HE`xdQi`iGt+VTJ*nd5lQavoF2!!IPN###c4ujp;B)e&}{r*}~ z*~8@8X9pL84eE7Q$H2h{w$Pum|Y04-DU%&qG~UbDfuCXnj-&fySvG+DfSo6c=t zVA%RqN>}=n0m8+dzVG?vD~O4q#>Wr7*#@LOjl0p}Ed%(hOPvTyec=R_ly-JO-Ct(_ z)Xp-&VhzP)baA=)Ht%#h^Ee<`c3c0oRXX@phWDfV$#-sz-^!hu1^KY$-XkTouG!)Am`u<(t8-wCMPaylM%a zWZM*G&!Y-Y_9P=#{&2+peUJ_u}+o=_Mj#UmK_4m&*ih6VpJR8d}J*T8$)L%rK zH%}I1SQZvXC-^K}I#EiBe)n1M`g6xktOdpi8`ekDU98?jnfXc%4i2_0^oRa~YH67$ z6`U~3iX+8*81gLiI43P-#bKeSH;Q_p=nNDU%2IH;IjAr`9ggA}JW;*0EXDr;G(u<_ zMDginp^AT9mfru_s0!Q&Toz0xpeM6^;3xAr3`%I>*ZfOsIp-HC5@0rg%we>J&LNKXqQ$f#Uq)Jd(acoq%hr0K3M9y zAYQjSQx%!rVPqtnh@)PlLS}#I_`|o7h!G+fa1@U+oQV>Ip_p~dL~OI-NC^K>7M~6` zO}u15{rO?2| zKYw>b@wa`ClVJLPUCkdDIZ< zJ2b<+frlKb_!sq3f;JCVflr500Tf;~(xdSO<^4W;;Wi z^KN{Pzx<0~CWGPh#Y0MHsYLvN*+nF1Lp-&qWYUD8*H3?b-~y2m_&R76{14QCqgWIE z-ToI}Gmhe3!;Jvp5Yr=r5RVs}y&R9)~B z5hZa{6o|}wZA@VNL|E`tf|14Z0-h!DjpAwG`yYmu7r8>wI&j6JbpX?FS=3;PgVPK* z0+&S0)^+Km}K~bE{ z(^1(mC@PUqg!gYa-@tSvZNIZZKy#6gaMX)$UaR=SLl`Fk9;-O|9liAY{0}6-OcecJ z81aMqH{9Z1V4}$JxWPa#MbkDhsAt&TPlww=Tj(yvTN11_M4c$XPhtGw#TF+%j^d3o zUb@7E8Qj0&B^XzMXB51c;d=Q@7nBSe9n({*Wrp zgo;ew5qBY7v4l^1TOV$)d7n|F^{A8{^KZp%*OT~tQqhZj|`k<_>{u97j7!k|4PDw7P>8L zXp06ymh69Vi*X|e6-0dFvOhv;{|j9qSaVwx|7%&Az8gAS-e$SJfckG&UHl5Fb>e4E zqFDUl`;E22Rq5%*tXtuX%Z_eC(qSaEO^e`iHl=Sa?x=Tp%C;_W2f(}JU+h>5rc zoQeEPW?uY3+wj%lOvKXxZiJEit1A%J z^5%zQ8F&2>3g!`6iYIFjXd+n(&-XZrhX~G3TtlmiV~YQlC(f9*D8^00{X^cb@qYh= zsb+YY#1ohq8PtHactb)?DV~w`5f*>A)%dP@k%owtf|CI28@_x0hjR|6I8GQ`gM_TG 
z{@`TBRp4uuP|R#9j*4dk^#=&^ify2UN`ikW?}$5xKAHW?BQCFuGrQE9)Y$&x!uTtw z-O>HmMo}FErT4!V0$Xt~VJi-P$%Rih5Y;PD)r9+RjS8Ei=tvYbK~em{wU~~){RjNT zSBR$~y!_*>0)9E*DcFC>8i1?7S&gH34rxu_Hs8OAsY5=R$Bn>IFLI;A{RftvtT=7S zdO$Hv-|V4V&%Up{i&S0O0;cE+wNBnZlQVlX-#D1)NvV~)z={Dl3lNCnP9kR*PAwKzWk3B>U?6SW$T zN?aDFxEHEdNDYbp7Y}!w%t93~v~U#HOPeOhKTrj}I@}20Owsn?Uc*ucU4OuIG;K2j z{lZ6^*#3~VKjbIla0xEs>x=Eg8mnOn+vaV_yt>Am^hp; zg6MyT60o>O#f`vaaU;Ni8nBk`m9h8(H(8>X!VeU9VB@h04!(@u@AkjAVw~b2f<*@Z z5;cVN2N%ZI(H2?UG;xvu^#@mss{pMzkF^;+Rt(&v)K=CK@yCRp1Z{CtTtLZ6*#(VY zH54@#pp?XPc%AEks=!I`OOKtI{AI|*zl0GaLB&iQgBroCLqdZIDvO{LeK^5s3p@d6 zYk?}qozI9$m?uU-{E1<5SO2FdMQewt|9?7NtJ!$-^FRL&M$oLSrn;<94m!Jrcf_4R zKcx8&j<~FUwGgoXc(bg8DhA;x%DJuXm-yF9_@iBZ-EO`AIH7Q^jiuP8V-Ty-*`?;!7OEg#9o6;)}RXC9@B18ZL_)EM79GKe)x-ZCD*$Lirz< zP{h;G%b;0%p>9BF^{n^2UYQdA`WG|wXN>v5Ka0^84UDX0{K+ynTtHhCgA^sI0vf1Z zWxW5@Muj`0D1Ntvrw*hw=k>q1Ff-dPZ3ig9fVDXJ#Y_b8HxpeUh#rwcaT2s8R{4bC z<$t)>xMG|~_;gNwpp@|GC*Eb_C{!WTpa>KTQM{)Y?GF6?H#~Rr zLlw72(Vi&U3q?oKC@=m6`vGWch7Q1s6x~JgKafVhSb_!dC(Z|C{CS~@C!;7%4ZQ6o z$wl-(I1_OrkW5MkGM+lf>BL-MW|OQxxDohu590`uPP7pKrCED1XWUmCOZp$5@>w{Q z_MOBDm%snKA#nGA5>d{@QB4Mh8t_8eTT)iA|H9J{RADM9%rE_*Lll{Yzoy`fY9O;6 zvH$o{4}b9=k8YeFcyKevP09VoEyiOWpn$%3$;2uU-T&hC1CQF-s6%J{dsNu};wo?> zaD#Ca*DGEI=s(2p&(n|>y5>1;K=~SY{DV;sjG+1Zd#Rzj)*Su+YSpFFguK8~$@s$) z8_qfWXlRML0KZdE#rt3UO(eX|01}AO0nRyu%e?*r~6>Z^Mt7Qo4Pb-*TY4FwI zvH+!7yFc;fQ033lf67{3{4q|H$OAZvm(|heK4_sRl7e;!MOVrRbgj z_P==K;AIt;#kq{5JcD`hFX4FHVo*{+vBlSoo7al9g7}+>Ius9g5NRSNwndht9MS*$ z=Zd}Bz6{CwgBw9wEN-E3;^P`{|I7mQho=?c9lwJB`Sd`BTi`;G;sjr4fm7T&Se^8o~npPcD8w~=}%BG@L z78P#GIQcyyL3xLnG$wv#GqFr^&Urrqn_B90RFe zT9~A)#rK3h?WMUgqHw;8k3=^l_pM9uOui#}hGi-_(V@LV{F^~;7e5t7ru-a-)oDK` zW*l-JfQia8%$n^D8Ycm!b{%6?W9(m`ycYrYzHFrYEh!>qiwTbL1CG5m-WZJEN&XT1 z%l-fnAllYzr^P#iE(A>Y2{7hek~R%MCGC%iPfTAN6)mx;Qs-5 zBrPRaI)o7Q3zCo)l@iGZH9q OS6RBSpsLO6%&xzKhXe`$ diff --git a/results/get_finetune_scores.py b/results/get_finetune_scores.py index dd31ad19..45b379a8 100644 --- a/results/get_finetune_scores.py +++ b/results/get_finetune_scores.py @@ -2,9 +2,10 @@ import pickle import pandas as pd -import wandb from tqdm import tqdm +import wandb + dataframe = pd.read_csv("runs_tables/finetune_urls.csv") api = wandb.Api(timeout=29) @@ -32,9 +33,14 @@ def get_run_scores(run_id, is_dt=False): break for _, row in run.history(keys=[score_key], samples=5000).iterrows(): full_scores.append(row[score_key]) + + for _, row in run.history(keys=["train/regret"], samples=5000).iterrows(): + if "train/regret" in row: + regret = row["train/regret"] for _, row in run.history(keys=["eval/regret"], samples=5000).iterrows(): if "eval/regret" in row: regret = row["eval/regret"] + offline_iters = len(full_scores) // 2 return full_scores[:offline_iters], full_scores[offline_iters:], regret diff --git a/results/get_finetune_tables_and_plots.py b/results/get_finetune_tables_and_plots.py index d8072652..997bc15a 100644 --- a/results/get_finetune_tables_and_plots.py +++ b/results/get_finetune_tables_and_plots.py @@ -76,6 +76,7 @@ def get_last_scores(avg_scores, avg_stds): for algo in full_offline_scores: for data in full_offline_scores[algo]: + # print(algo, flush=True) full_offline_scores[algo][data] = [s[0] for s in full_scores[algo][data]] full_online_scores[algo][data] = [s[1] for s in full_scores[algo][data]] regrets[algo][data] = np.mean([s[2] for s in full_scores[algo][data]]) @@ -127,7 +128,7 @@ def add_domains_avg(scores): add_domains_avg(last_online_scores) add_domains_avg(regrets) -algorithms = ["AWAC", "CQL", "IQL", "SPOT", "Cal-QL"] +algorithms = ["AWAC", "CQL", "IQL", 
"SPOT", "Cal-QL", "ReBRAC",] datasets = dataframe["dataset"].unique() ordered_datasets = [ "antmaze-umaze-v2", @@ -417,8 +418,9 @@ def flatten(data): "CQL", "IQL", "AWAC", + "Cal-QL", ] -for a1 in ["Cal-QL"]: +for a1 in ["ReBRAC"]: for a2 in algs: algorithm_pairs[f"{a1},{a2}"] = (flat[a1], flat[a2]) average_probabilities, average_prob_cis = rly.get_interval_estimates( diff --git a/results/get_finetune_urls.py b/results/get_finetune_urls.py index 12d7e374..4059dee7 100644 --- a/results/get_finetune_urls.py +++ b/results/get_finetune_urls.py @@ -1,4 +1,5 @@ import pandas as pd + import wandb collected_urls = { @@ -18,6 +19,8 @@ def get_urls(sweep_id, algo_name): dataset = run.config["env"] elif "env_name" in run.config: dataset = run.config["env_name"] + elif "dataset_name" in run.config: + dataset = run.config["dataset_name"] name = algo_name if "10" in "-".join(run.name.split("-")[:-1]): name = "10% " + name @@ -41,6 +44,8 @@ def get_urls(sweep_id, algo_name): get_urls("tlab/CORL/sweeps/efvz7d68", "Cal-QL") +get_urls("tlab/CORL/sweeps/62cgqb8c", "ReBRAC") + dataframe = pd.DataFrame(collected_urls) dataframe.to_csv("runs_tables/finetune_urls.csv", index=False) diff --git a/results/runs_tables/finetune_urls.csv b/results/runs_tables/finetune_urls.csv index 3058d67a..ae637030 100644 --- a/results/runs_tables/finetune_urls.csv +++ b/results/runs_tables/finetune_urls.csv @@ -32,8 +32,8 @@ SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/0kn4pl04 SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/onon4kbo SPOT,antmaze-large-play-v2,tlab/CORL/runs/5ldclyhi SPOT,antmaze-large-play-v2,tlab/CORL/runs/v8uskc0k -SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/tnksp757 +SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/sysutdr0 SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/cg7vkg7p SPOT,antmaze-medium-diverse-v2,tlab/CORL/runs/12ivynlo SPOT,antmaze-large-diverse-v2,tlab/CORL/runs/kvxnj3cw @@ -75,20 +75,20 @@ AWAC,antmaze-large-play-v2,tlab/CORL/runs/z88yrg2k AWAC,antmaze-large-play-v2,tlab/CORL/runs/b1e1up19 AWAC,antmaze-large-play-v2,tlab/CORL/runs/bpq150gg AWAC,antmaze-large-play-v2,tlab/CORL/runs/rffmurq2 -AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/uo63dgzx +AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/tiq65215 AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/mfjh57xv AWAC,antmaze-medium-diverse-v2,tlab/CORL/runs/325fy2js CQL,antmaze-umaze-v2,tlab/CORL/runs/vdh2wmw9 CQL,antmaze-umaze-v2,tlab/CORL/runs/wh27aupq CQL,antmaze-umaze-v2,tlab/CORL/runs/7r4uwutz -CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz CQL,antmaze-umaze-v2,tlab/CORL/runs/l5xvgwt4 +CQL,antmaze-large-play-v2,tlab/CORL/runs/kt7jwqcz CQL,antmaze-large-play-v2,tlab/CORL/runs/8fm40vpm CQL,antmaze-large-play-v2,tlab/CORL/runs/yeax28su -CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo -CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/pswhm9pi +CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/mowkqr6u +CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/gvhslqyo CQL,antmaze-medium-diverse-v2,tlab/CORL/runs/hh5vv5qc CQL,door-cloned-v1,tlab/CORL/runs/2xc0y5sd CQL,door-cloned-v1,tlab/CORL/runs/2ylul8yr @@ -188,14 +188,54 @@ Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/a015fjb1 Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/1pu06s2i Cal-QL,antmaze-umaze-diverse-v2,tlab/CORL/runs/iwa1o31k Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/yvqv3mxa -Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g 
Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/6ptdr78l +Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/4myjeu5g Cal-QL,antmaze-large-diverse-v2,tlab/CORL/runs/8ix0469p -Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fzrlcnwp +Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/4chdwkua Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/f9hz4fal Cal-QL,antmaze-large-play-v2,tlab/CORL/runs/fpq2ob8q Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/zhf7tr7p Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/m02ew5oy Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/9r1a0trx Cal-QL,antmaze-medium-diverse-v2,tlab/CORL/runs/ds2dbx2u +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/jt59ttpm +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/i8v1hu1q +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/qedc68y9 +ReBRAC,antmaze-umaze-v2,tlab/CORL/runs/69ax7dcn +ReBRAC,door-cloned-v1,tlab/CORL/runs/eqfcelaj +ReBRAC,door-cloned-v1,tlab/CORL/runs/prxbzf7x +ReBRAC,door-cloned-v1,tlab/CORL/runs/tdcevdlp +ReBRAC,door-cloned-v1,tlab/CORL/runs/kl86pkrh +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/gpzs7ko9 +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/mpqjq5iv +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/5fu56nvb +ReBRAC,hammer-cloned-v1,tlab/CORL/runs/dfzw050e +ReBRAC,pen-cloned-v1,tlab/CORL/runs/j29tzijm +ReBRAC,pen-cloned-v1,tlab/CORL/runs/4en0ayoi +ReBRAC,pen-cloned-v1,tlab/CORL/runs/4eyhtq9d +ReBRAC,pen-cloned-v1,tlab/CORL/runs/so5u47bk +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/yskdxtsn +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/32gy3ulp +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/l4jqb7k7 +ReBRAC,relocate-cloned-v1,tlab/CORL/runs/mgd0l6wd +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/owmozman +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/0w7wg74n +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/arp17u1o +ReBRAC,antmaze-medium-play-v2,tlab/CORL/runs/erip8h20 +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/5pd0cva3 +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/2beklksm +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/68m94c0v +ReBRAC,antmaze-umaze-diverse-v2,tlab/CORL/runs/g5cshcb2 +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/qyggi592 +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/4u324o5s +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/vbgrhecg +ReBRAC,antmaze-large-diverse-v2,tlab/CORL/runs/2286tpe9 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/24zf9wms +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wc7z2xz3 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/wxq1x3o2 +ReBRAC,antmaze-large-play-v2,tlab/CORL/runs/p8j2dn0y +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/0ig2xzi7 +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/puzizafo +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/m84hfj5d +ReBRAC,antmaze-medium-diverse-v2,tlab/CORL/runs/u1y0nldu From d60ad7b035c21a4fc20a5dcdb1992a91dd795c10 Mon Sep 17 00:00:00 2001 From: Denis Tarasov Date: Fri, 15 Sep 2023 12:21:59 +0200 Subject: [PATCH 09/15] Change arrows in README --- README.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 917b6a60..8e823a5a 100644 --- a/README.md +++ b/README.md @@ -171,21 +171,21 @@ You can check the links above for learning curves and details. 
Here, we report r #### Scores | **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| |---------------------------|------------|--------|--------|-----|-----|-----| -|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 -> 74.75 ± 42.59| -|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 -> 98.00 ± 2.92| -|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 -> 98.00 ± 1.58| -|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 -> 98.75 ± 0.43| -|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 -> 31.50 ± 33.56| -|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 -> 72.25 ± 41.73| +|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 →gi 74.75 ± 42.59| +|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 →gi 98.00 ± 2.92| +|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 →gi 98.00 ± 1.58| +|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 →gi 98.75 ± 0.43| +|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 →gi 31.50 ± 33.56| +|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 →gi 72.25 ± 41.73| | | | | | | | | | | | -| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 -> 78.88| +| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 →gi 78.88| | | | | | | | | | | | -|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 -> 138.15 ± 3.22| -|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 -> 102.39 ± 8.27| -|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 -> 124.65 ± 7.37| -|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 -> 6.96 ± 4.59| +|pen-cloned-v1|88.66 ± 15.10 → 86.82 
± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 →gi 138.15 ± 3.22| +|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 →gi 102.39 ± 8.27| +|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 →gi 124.65 ± 7.37| +|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 →gi 6.96 ± 4.59| | | | | | | | | | | | -| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 -> 93.04| +| **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 → 93.04| #### Regrets | **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| From d44a99502c975c8a67929bc1130ae6d90c832993 Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:46:34 +0200 Subject: [PATCH 10/15] Update get_finetune_scores.py --- results/get_finetune_scores.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/results/get_finetune_scores.py b/results/get_finetune_scores.py index 45b379a8..9201526d 100644 --- a/results/get_finetune_scores.py +++ b/results/get_finetune_scores.py @@ -2,9 +2,9 @@ import pickle import pandas as pd -from tqdm import tqdm import wandb +from tqdm import tqdm dataframe = pd.read_csv("runs_tables/finetune_urls.csv") From 7f002ec7c5a9ae4d34056c3eede23b4ff0c21f45 Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:47:23 +0200 Subject: [PATCH 11/15] Update rebrac.py --- algorithms/finetune/rebrac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index f5311b21..e9eac6b8 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -20,11 +20,11 @@ import optax import pyrallis import tqdm +import wandb from flax.core import FrozenDict from flax.training.train_state import TrainState from tqdm.auto import trange -import wandb ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") From 0a185d3b6929651462f31b42e9f27bc0538f62bd Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:47:47 +0200 Subject: [PATCH 12/15] Update get_finetune_urls.py --- results/get_finetune_urls.py | 1 - 1 file changed, 1 deletion(-) diff --git a/results/get_finetune_urls.py b/results/get_finetune_urls.py index 4059dee7..07048ca3 100644 --- a/results/get_finetune_urls.py +++ b/results/get_finetune_urls.py @@ -1,5 +1,4 @@ import pandas as pd - import wandb collected_urls = { From 70aacd796f96c28b91f1ba39fa2788099aa2f96e Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:53:02 +0200 Subject: [PATCH 13/15] Update get_finetune_scores.py --- results/get_finetune_scores.py | 1 - 1 file changed, 1 deletion(-) diff --git a/results/get_finetune_scores.py b/results/get_finetune_scores.py index 9201526d..1fb9303d 100644 --- a/results/get_finetune_scores.py +++ b/results/get_finetune_scores.py @@ -2,7 +2,6 @@ import pickle import pandas as pd - import wandb from tqdm import tqdm From d55af0e983095d9020f71b6593215e1cb0c6d25e 
Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:53:23 +0200 Subject: [PATCH 14/15] Update rebrac.py --- algorithms/finetune/rebrac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algorithms/finetune/rebrac.py b/algorithms/finetune/rebrac.py index e9eac6b8..26b0d661 100644 --- a/algorithms/finetune/rebrac.py +++ b/algorithms/finetune/rebrac.py @@ -25,7 +25,6 @@ from flax.training.train_state import TrainState from tqdm.auto import trange - ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate") default_kernel_init = nn.initializers.lecun_normal() From c8a989b682bc8b6403f73849810591572a33d5d5 Mon Sep 17 00:00:00 2001 From: Denis Tarasov <39963896+DT6A@users.noreply.github.com> Date: Wed, 6 Dec 2023 23:39:15 +0100 Subject: [PATCH 15/15] Update README.md --- README.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 3813396a..88b8a19a 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,7 @@ docker run --gpus=all -it --rm --name | ✅ [Conservative Q-Learning for Offline Reinforcement Learning
(CQL)](https://arxiv.org/abs/2006.04779) | [`offline/cql.py`](algorithms/offline/cql.py)
[`finetune/cql.py`](algorithms/finetune/cql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-CQL--Vmlldzo1MzM4MjY3)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-CQL--Vmlldzo0NTQ3NTMz) | ✅ [Accelerating Online Reinforcement Learning with Offline Datasets
(AWAC)](https://arxiv.org/abs/2006.09359) | [`offline/awac.py`](algorithms/offline/awac.py)
[`finetune/awac.py`](algorithms/finetune/awac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-AWAC--Vmlldzo1MzM4MTEy)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-AWAC--VmlldzozODAyNzQz) | ✅ [Offline Reinforcement Learning with Implicit Q-Learning
(IQL)](https://arxiv.org/abs/2110.06169) | [`offline/iql.py`](algorithms/offline/iql.py)
[`finetune/iql.py`](algorithms/finetune/iql.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-IQL--Vmlldzo1MzM4MzQz)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-IQL--VmlldzozNzE1MTEy) +| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py)
[`finetune/rebrac.py`](algorithms/finetune/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2)

[`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-ReBRAC--Vmlldzo1NDAyNjE5) | **Offline-to-Online only** | | | ✅ [Supported Policy Optimization for Offline Reinforcement Learning
(SPOT)](https://arxiv.org/abs/2202.06239) | [`finetune/spot.py`](algorithms/finetune/spot.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-SPOT--VmlldzozODk5MTgx) | ✅ [Cal-QL: Calibrated Offline RL Pre-Training for Efficient Online Fine-Tuning
(Cal-QL)](https://arxiv.org/abs/2303.05479) | [`finetune/cal_ql.py`](algorithms/finetune/cal_ql.py) | [`Offline-to-online`](https://wandb.ai/tlab/CORL/reports/-Offline-to-Online-Cal-QL--Vmlldzo0NTQ3NDk5) @@ -57,7 +58,6 @@ docker run --gpus=all -it --rm --name | ✅ [Decision Transformer: Reinforcement Learning via Sequence Modeling
(DT)](https://arxiv.org/abs/2106.01345) | [`offline/dt.py`](algorithms/offline/dt.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-Decision-Transformer--Vmlldzo1MzM3OTkx) | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(SAC-N)](https://arxiv.org/abs/2110.01548) | [`offline/sac_n.py`](algorithms/offline/sac_n.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-SAC-N--VmlldzoyNzA1NTY1) | ✅ [Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble
(EDAC)](https://arxiv.org/abs/2110.01548) | [`offline/edac.py`](algorithms/offline/edac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-EDAC--VmlldzoyNzA5ODUw) -| ✅ [Revisiting the Minimalist Approach to Offline Reinforcement Learning
(ReBRAC)](https://arxiv.org/abs/2305.09836) | [`offline/rebrac.py`](algorithms/offline/rebrac.py) | [`Offline`](https://wandb.ai/tlab/CORL/reports/-Offline-ReBRAC--Vmlldzo0ODkzOTQ2) | ✅ [Q-Ensemble for Offline RL: Don't Scale the Ensemble, Scale the Batch Size
(LB-SAC)](https://arxiv.org/abs/2211.11092) | [`offline/lb_sac.py`](algorithms/offline/lb_sac.py) | [`Offline Gym-MuJoCo`](https://wandb.ai/tlab/CORL/reports/LB-SAC-D4RL-Results--VmlldzozNjIxMDY1) @@ -181,19 +181,19 @@ You can check the links above for learning curves and details. Here, we report r #### Scores | **Task-Name** |AWAC|CQL|IQL|SPOT|Cal-QL|ReBRAC| |---------------------------|------------|--------|--------|-----|-----|-----| -|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 →gi 74.75 ± 42.59| -|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 →gi 98.00 ± 2.92| -|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 →gi 98.00 ± 1.58| -|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 →gi 98.75 ± 0.43| -|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 →gi 31.50 ± 33.56| -|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 →gi 72.25 ± 41.73| +|antmaze-umaze-v2|52.75 ± 8.67 → 98.75 ± 1.09|94.00 ± 1.58 → 99.50 ± 0.87|77.00 ± 0.71 → 96.50 ± 1.12|91.00 ± 2.55 → 99.50 ± 0.50|76.75 ± 7.53 → 99.75 ± 0.43|98.00 ± 1.58 → 74.75 ± 42.59| +|antmaze-umaze-diverse-v2|56.00 ± 2.74 → 0.00 ± 0.00|9.50 ± 9.91 → 99.00 ± 1.22|59.50 ± 9.55 → 63.75 ± 25.02|36.25 ± 2.17 → 95.00 ± 3.67|32.00 ± 27.79 → 98.50 ± 1.12|73.75 ± 13.27 → 98.00 ± 2.92| +|antmaze-medium-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|59.00 ± 11.18 → 97.75 ± 1.30|71.75 ± 2.95 → 89.75 ± 1.09|67.25 ± 10.47 → 97.25 ± 1.30|71.75 ± 3.27 → 98.75 ± 1.64|87.50 ± 3.77 → 98.00 ± 1.58| +|antmaze-medium-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|63.50 ± 6.84 → 97.25 ± 1.92|64.25 ± 1.92 → 92.25 ± 2.86|73.75 ± 7.29 → 94.50 ± 1.66|62.00 ± 4.30 → 98.25 ± 1.48|85.25 ± 2.17 → 98.75 ± 0.43| +|antmaze-large-play-v2|0.00 ± 0.00 → 0.00 ± 0.00|28.75 ± 7.76 → 88.25 ± 2.28|38.50 ± 8.73 → 64.50 ± 17.04|31.50 ± 12.58 → 87.00 ± 3.24|31.75 ± 8.87 → 97.25 ± 1.79|68.50 ± 6.18 → 31.50 ± 33.56| +|antmaze-large-diverse-v2|0.00 ± 0.00 → 0.00 ± 0.00|35.50 ± 3.64 → 91.75 ± 3.96|26.75 ± 3.77 → 64.25 ± 4.15|17.50 ± 7.26 → 81.00 ± 14.14|44.00 ± 8.69 → 91.50 ± 3.91|67.00 ± 10.61 → 72.25 ± 41.73| | | | | | | | | | | | -| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 →gi 78.88| +| **antmaze average** |18.12 → 16.46|48.38 → 95.58|56.29 → 78.50|52.88 → 92.38|53.04 → 97.33|80.00 → 78.88| | | | | | | | | | | | -|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 →gi 138.15 ± 3.22| -|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 →gi 102.39 ± 8.27| -|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 
4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 →gi 124.65 ± 7.37| -|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 →gi 6.96 ± 4.59| +|pen-cloned-v1|88.66 ± 15.10 → 86.82 ± 11.12|-2.76 ± 0.08 → -1.28 ± 2.16|84.19 ± 3.96 → 102.02 ± 20.75|6.19 ± 5.21 → 43.63 ± 20.09|-2.66 ± 0.04 → -2.68 ± 0.12|74.04 ± 11.97 → 138.15 ± 3.22| +|door-cloned-v1|0.93 ± 1.66 → 0.01 ± 0.00|-0.33 ± 0.01 → -0.33 ± 0.01|1.19 ± 0.93 → 20.34 ± 9.32|-0.21 ± 0.14 → 0.02 ± 0.31|-0.33 ± 0.01 → -0.33 ± 0.01|0.07 ± 0.04 → 102.39 ± 8.27| +|hammer-cloned-v1|1.80 ± 3.01 → 0.24 ± 0.04|0.56 ± 0.55 → 2.85 ± 4.81|1.35 ± 0.32 → 57.27 ± 28.49|3.97 ± 6.39 → 3.73 ± 4.99|0.25 ± 0.04 → 0.17 ± 0.17|6.54 ± 3.35 → 124.65 ± 7.37| +|relocate-cloned-v1|-0.04 ± 0.04 → -0.04 ± 0.01|-0.33 ± 0.01 → -0.33 ± 0.01|0.04 ± 0.04 → 0.32 ± 0.38|-0.24 ± 0.01 → -0.15 ± 0.05|-0.31 ± 0.05 → -0.31 ± 0.04|0.70 ± 0.62 → 6.96 ± 4.59| | | | | | | | | | | | | **adroit average** |22.84 → 21.76|-0.72 → 0.22|21.69 → 44.99|2.43 → 11.81|-0.76 → -0.79|20.33 → 93.04|
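For reference, the scores in the tables above are D4RL-normalized: the finetune script logs `eval_env.get_normalized_score(eval_returns) * 100.0`, which rescales the raw episodic return so that 0 roughly corresponds to a random policy and 100 to the expert reference return. A minimal sketch of that rescaling follows; the reference returns used here are placeholder numbers for illustration, not the actual D4RL constants.

```python
def d4rl_normalized_score(raw_return, random_return, expert_return):
    # 0.0 ~ random-policy return, 100.0 ~ expert reference return.
    # Scores above 100 (e.g. pen-cloned-v1 after online tuning) simply mean
    # the policy's return exceeds the expert reference.
    return 100.0 * (raw_return - random_return) / (expert_return - random_return)


# Placeholder numbers for illustration only (not the real D4RL reference returns):
print(d4rl_normalized_score(raw_return=2500.0, random_return=100.0, expert_return=3200.0))
# -> roughly 77.4
```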