diff --git a/experiments/data_provider.py b/experiments/data_provider.py
index 621d9351..b00215ff 100644
--- a/experiments/data_provider.py
+++ b/experiments/data_provider.py
@@ -198,9 +198,10 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points:
         num_train_traj = 8
         recordings_dir = [os.path.join(DATA_DIR, 'recordings_rc_car_v1')]
     elif car_id == 2:
-        num_train_traj = 10
+        num_train_traj = 12
         recordings_dir = [os.path.join(DATA_DIR, 'recordings_rc_car_v2'),
-                          os.path.join(DATA_DIR, 'recordings_rc_car_v3')]
+                          os.path.join(DATA_DIR, 'recordings_rc_car_v3'),
+                          os.path.join(DATA_DIR, 'recordings_rc_car_v4')]
     else:
         raise ValueError(f"Unknown car id {car_id}")
     files = [sorted(glob.glob(rd + '/*.pickle')) for rd in recordings_dir]
@@ -210,14 +211,20 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points:

     # load and shuffle transitions
     transitions = _load_transitions(file_names)
-    indices = jax.random.permutation(key=jax.random.PRNGKey(9345), x=jnp.arange(0, len(transitions)))
-    transitions = [transitions[idx] for idx in indices]
+    # indices = jax.random.permutation(key=jax.random.PRNGKey(9345), x=jnp.arange(0, len(transitions)))
+    # transitions = [transitions[idx] for idx in indices]

     # transform transitions into supervised learning datasets
     prep_fn = partial(_rccar_transitions_to_dataset, encode_angles=encode_angle, skip_first_n=skip_first_n_points,
                       action_delay=action_delay, action_stacking=action_stacking)
-    x_train, y_train = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[:num_train_traj])))
-    x_test, y_test = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[num_train_traj:])))
+    x, y = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions)))
+    # x_test, y_test = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[num_train_traj:])))
+    indices = jnp.arange(start=0, stop=x.shape[0], step=1)
+    indices = jax.random.shuffle(key=jax.random.PRNGKey(9345), x=indices)
+    x, y = x[indices], y[indices]
+    num_test_points = 20_000
+    x_train, y_train, x_test, y_test = x[:-num_test_points], y[:-num_test_points], \
+        x[-num_test_points:], y[-num_test_points:]

     return x_train, y_train, x_test, y_test
diff --git a/experiments/offline_rl_from_recorded_data/exp.py b/experiments/offline_rl_from_recorded_data/exp.py
index 6e98b4af..85c0fa61 100644
--- a/experiments/offline_rl_from_recorded_data/exp.py
+++ b/experiments/offline_rl_from_recorded_data/exp.py
@@ -217,7 +217,7 @@ def experiment(horizon_len: int,
         if high_fidelity:
             outputscales_racecar = [0.008, 0.008, 0.009, 0.009, 0.05, 0.05, 0.20]
         else:
-            outputscales_racecar = [0.008, 0.008, 0.01, 0.01, 0.08, 0.08, 0.5]
+            outputscales_racecar = [0.008, 0.008, 0.01, 0.01, 0.1, 0.1, 0.5]
         sim = AdditiveSim(base_sims=[sim, GaussianProcessSim(sim.input_size, sim.output_size,
                                                              output_scale=outputscales_racecar,
diff --git a/experiments/offline_rl_from_recorded_data/launcher.py b/experiments/offline_rl_from_recorded_data/launcher.py
index 56ed1dad..b9a55982 100644
--- a/experiments/offline_rl_from_recorded_data/launcher.py
+++ b/experiments/offline_rl_from_recorded_data/launcher.py
@@ -1,7 +1,7 @@
 import exp
 from experiments.util import generate_run_commands, generate_base_command, dict_permutations

-PROJECT_NAME = 'OfflineRLRunsSimConsecutive'
+PROJECT_NAME = 'OfflineRLRunsGreyHW'

 _applicable_configs = {
     'horizon_len': [200],
@@ -9,8 +9,8 @@
     'project_name': [PROJECT_NAME],
     'sac_num_env_steps': [2_000_000],
     'num_epochs': [50],
-    'max_train_steps': [100_000],
-    'min_train_steps': [40_000],  # for HW 30_000 worked the best.
+    'max_train_steps': [40_000],
+    'min_train_steps': [40_000],
     'learnable_likelihood_std': ['yes'],
     'include_aleatoric_noise': [1],
     'best_bnn_model': [1],
@@ -18,35 +18,39 @@
     'margin_factor': [20.0],
     'ctrl_cost_weight': [0.005],
     'ctrl_diff_weight': [0.0],
-    'num_offline_collected_transitions': [20, 50, 100, 200, 400, 800, 1600, 2000, 2500, 10_000, 20_000],
+    'num_offline_collected_transitions': [20, 50, 100, 200, 400, 800, 1600, 2000, 2500, 5_000, 10_000, 20_000],
     'test_data_ratio': [0.0],
     'eval_on_all_offline_data': [1],
     'eval_only_on_init_states': [1],
     'share_of_x0s_in_sac_buffer': [0.5],
-    'bnn_batch_size': [32],  # for HW 8 worked the best
+    'bnn_batch_size': [32],
     'likelihood_exponent': [0.5],
     'train_sac_only_from_init_states': [0],
-    'data_from_simulation': [1],
+    'data_from_simulation': [0],
     'num_frame_stack': [3],
     'bandwidth_svgd': [0.2],
-    'length_scale_aditive_sim_gp': [10.0],
+    'length_scale_aditive_sim_gp': [5.0],
     'input_from_recorded_data': [1],
-    'obtain_consecutive_data': [0, 1],
+    'obtain_consecutive_data': [1],
+    'lr': [3e-4],
 }

 _applicable_configs_no_sim_prior = {'use_sim_prior': [0],
                                     'use_grey_box': [0],
+                                    'use_sim_model': [0],
                                     'high_fidelity': [0],
                                     'predict_difference': [1],
                                     'num_measurement_points': [8]
                                     } | _applicable_configs

 _applicable_configs_high_fidelity = {'use_sim_prior': [1],
                                      'use_grey_box': [0],
+                                     'use_sim_model': [0],
                                      'high_fidelity': [1],
                                      'predict_difference': [1],
                                      'num_measurement_points': [8]} | _applicable_configs

 _applicable_configs_low_fidelity = {'use_sim_prior': [1],
                                     'use_grey_box': [0],
+                                    'use_sim_model': [0],
                                     'high_fidelity': [0],
                                     'predict_difference': [1],
                                     'num_measurement_points': [8]} | _applicable_configs

@@ -54,26 +58,46 @@
 _applicable_configs_grey_box_low_fidelity = {'use_sim_prior': [0],
                                              'high_fidelity': [0],
                                              'use_grey_box': [1],
-                                             'predict_difference': [0],
+                                             'use_sim_model': [0],
+                                             'predict_difference': [1],
                                              'num_measurement_points': [8]} | _applicable_configs

 _applicable_configs_grey_box_high_fidelity = {'use_sim_prior': [0],
                                               'high_fidelity': [1],
                                               'use_grey_box': [1],
-                                              'predict_difference': [0],
+                                              'use_sim_model': [0],
+                                              'predict_difference': [1],
+                                              'num_measurement_points': [8]} | _applicable_configs
+
+_applicable_configs_sim_model_high_fidelity = {'use_sim_prior': [0],
+                                               'high_fidelity': [1],
+                                               'use_grey_box': [0],
+                                               'use_sim_model': [1],
+                                               'predict_difference': [1],
+                                               'num_measurement_points': [8]} | _applicable_configs
+
+_applicable_configs_sim_model_low_fidelity = {'use_sim_prior': [0],
+                                              'high_fidelity': [0],
+                                              'use_grey_box': [0],
+                                              'use_sim_model': [1],
+                                              'predict_difference': [1],
                                               'num_measurement_points': [8]} | _applicable_configs

 # all_flags_combinations = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
 #     _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + dict_permutations(
 #     _applicable_configs_grey_box)
-all_flags_combinations = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
-    _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity)  # + dict_permutations(
+sim_flags = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
+    _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + \
+            dict_permutations(_applicable_configs_grey_box_low_fidelity) + \
+            dict_permutations(_applicable_configs_sim_model_low_fidelity)

-# _applicable_configs_grey_box)
+hw_flags = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
+    _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + \
+           dict_permutations(_applicable_configs_grey_box_high_fidelity) + \
+           dict_permutations(_applicable_configs_sim_model_high_fidelity)

-all_flags_combinations += dict_permutations(_applicable_configs_grey_box_low_fidelity) + dict_permutations(
-    _applicable_configs_grey_box_high_fidelity)
+all_flags_combinations = sim_flags


 def main():
diff --git a/experiments/online_rl_hardware/launcher.py b/experiments/online_rl_hardware/launcher.py
index 6d7a7f9f..028c9a20 100644
--- a/experiments/online_rl_hardware/launcher.py
+++ b/experiments/online_rl_hardware/launcher.py
@@ -1,14 +1,23 @@
 import online_rl_loop
 from experiments.util import generate_run_commands, generate_base_command, dict_permutations

+
 def main(args):
     _applicable_configs = {
-        'prior': ['none_FVSGD', 'none_SVGD', 'high_fidelity', 'low_fidelity'],  # 'high_fidelity_no_aditive_GP'],
+        'prior': ['none_FVSGD', 'high_fidelity', 'low_fidelity',
+                  'low_fidelity_grey_box'],
         'seed': list(range(5)),
-        'run_remote': [0],
+        'machine': ['local'],
         'gpu': [1],
-        'wandb_tag': ['gpu' if args.num_gpus > 0 else 'cpu'],
-        'project_name': ['OnlineRLDebug3'],
+        'project_name': ['OnlineRLTestFull'],
+        'reset_bnn': [1],
+        'deterministic_policy': [1],
+        'initial_state_fraction': [0.5],
+        'bnn_train_steps': [40_000],
+        'sac_num_env_steps': [500_000],
+        'num_sac_envs': [128],
+        'num_env_steps': [100],
+        'num_f_samples': [512]
     }

     all_flags_combinations = dict_permutations(_applicable_configs)
@@ -25,8 +34,9 @@ def main(args):

 if __name__ == '__main__':
     import argparse
+
     parser = argparse.ArgumentParser(description='Meta-BO run')
-    parser.add_argument('--num_cpus', type=int, default=2)
+    parser.add_argument('--num_cpus', type=int, default=1)
     parser.add_argument('--num_gpus', type=int, default=1)
     args = parser.parse_args()
     main(args)
diff --git a/experiments/online_rl_hardware/online_rl_loop.py b/experiments/online_rl_hardware/online_rl_loop.py
index fdc55ce7..1120c2da 100644
--- a/experiments/online_rl_hardware/online_rl_loop.py
+++ b/experiments/online_rl_hardware/online_rl_loop.py
@@ -1,7 +1,6 @@
 import json
 import os
 import pickle
-import random
 import sys
 from pprint import pprint
 from typing import Any, NamedTuple
@@ -16,21 +15,22 @@
 from experiments.online_rl_hardware.train_policy import ModelBasedRLConfig
 from experiments.online_rl_hardware.train_policy import train_model_based_policy
 from experiments.online_rl_hardware.utils import (set_up_bnn_dynamics_model, set_up_dummy_sac_trainer,
-                                                  dump_trajectory_summary, execute)
+                                                  dump_trajectory_summary, execute,
+                                                  prepare_init_transitions_for_car_env, get_random_hash)
 from experiments.util import Logger, RESULT_DIR
 from sim_transfer.sims.envs import RCCarSimEnv
 from sim_transfer.sims.util import plot_rc_trajectory

-
-WANDB_ENTITY = 'jonasrothfuss'
-EULER_ENTITY = 'rojonas'
-WANDB_LOG_DIR_EULER = '/cluster/scratch/' + EULER_ENTITY

 PRIORS = {'none_FVSGD',
           'none_SVGD',
           'high_fidelity',
           'low_fidelity',
           'high_fidelity_no_aditive_GP',
+          'high_fidelity_grey_box',
+          'low_fidelity_grey_box',
+          'high_fidelity_sim',
+          'low_fidelity_sim'
           }

@@ -38,26 +38,28 @@ def _load_remote_config(machine: str):
     # load remote config
     with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'remote_config.json'), 'r') as f:
-        remote_config = json.load(f)
+        config = json.load(f)

     # choose machine
-    assert machine in remote_config, f'Machine {machine} not found in remote config. ' \
-                                     f'Available machines: {list(remote_config.keys())}'
-    remote_config = remote_config[machine]
+    assert machine in config['remote_machines'], \
+        f'Machine {machine} not found in remote config. Available machines: {list(config["remote_machines"].keys())}'
+    remote_config = config['remote_machines'][machine]
+
+    assert 'user_config' in config, 'No user config found in remote config.'
+    user_config = config['user_config']
+    user_config['wandb_log_dir_euler'] = '/cluster/scratch/' + user_config['euler_entity']

-    # create local director if it does not exist
+    # create local directory if it does not exist
     os.makedirs(remote_config['local_dir'], exist_ok=True)

     # print remote config
     print(f'Remote config [{machine}]:')
     pprint(remote_config)
-    print('')
-
-    return remote_config
+    print('\nUser config:')
+    pprint(user_config)

-def _get_random_hash() -> str:
-    return "%032x" % random.getrandbits(128)
+    return remote_config, user_config


 def train_model_based_policy_remote(*args,
@@ -82,14 +84,14 @@ def train_model_based_policy_remote(*args,
     if machine == 'local':
         # if not running remotely, just run the function locally and return the result
         return train_model_based_policy(*args, **kwargs)
-    rmt_cfg = _load_remote_config(machine=machine)
+    rmt_cfg, _ = _load_remote_config(machine=machine)

     # copy latest version of train_policy.py to remote and make sure remote directory exists
     execute(f'scp {rmt_cfg["local_script"]} {rmt_cfg["remote_machine"]}:{rmt_cfg["remote_script"]}', verbosity)
     execute(f'ssh {rmt_cfg["remote_machine"]} "mkdir -p {rmt_cfg["remote_dir"]}"', verbosity)

     # dump train_data to local pkl file
-    run_hash = _get_random_hash()
+    run_hash = get_random_hash()
     train_data_path_local = os.path.join(rmt_cfg['local_dir'], f'train_data_{run_hash}.pkl')
     with open(train_data_path_local, 'wb') as f:
         pickle.dump({'args': args, 'kwargs': kwargs}, f)
@@ -104,12 +106,11 @@ def train_model_based_policy_remote(*args,
     # run the train_policy.py script on the remote machine
     result_path_remote = os.path.join(rmt_cfg['remote_dir'], f'result_{run_hash}.pkl')
-    command = f'export PYTHONPATH={rmt_cfg["remote_pythonpath"]} && ' \
-              f'{rmt_cfg["remote_interpreter"]} {rmt_cfg["remote_script"]} ' \
+    command = f'{rmt_cfg["remote_interpreter"]} {rmt_cfg["remote_script"]} ' \
               f'--data_load_path {train_data_path_remote} --model_dump_path {result_path_remote}'
     if verbosity:
         print('[Local] Executing command:', command)
-    execute(f'ssh {rmt_cfg["remote_machine"]} "{command}"', verbosity)
+    execute(f'ssh -tt {rmt_cfg["remote_machine"]} "{rmt_cfg["remote_pre_cmd"]} {command}"', verbosity)

     # transfer result back to local
     result_path_local = os.path.join(rmt_cfg['local_dir'], f'result_{run_hash}.pkl')
@@ -139,6 +140,7 @@ class MainConfig(NamedTuple):
     include_aleatoric_noise: int = 1
     best_bnn_model: int = 1
     best_policy: int = 1
+    deterministic_policy: int = 1
     predict_difference: int = 1
     margin_factor: float = 20.0
     ctrl_cost_weight: float = 0.005
@@ -148,22 +150,47 @@ class MainConfig(NamedTuple):
     likelihood_exponent: float = 0.5
     data_batch_size: int = 32
     bandwidth_svgd: float = 0.2
-    length_scale_aditive_sim_gp: float = 10.0
+    length_scale_aditive_sim_gp: float = 5.0
     num_f_samples: int = 512
     num_measurement_points: int = 16
+    initial_state_fraction: float = 0.5
+    sim: int = 1
+    control_time_ms: float = 24.
+    num_sac_envs: int = 64
+    eval_only_on_init_states: int = 1


 def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
          machine: str = 'local'):
     rng_key_env, rng_key_model, rng_key_rollouts = jax.random.split(jax.random.PRNGKey(config.seed), 3)
-    env = RCCarSimEnv(encode_angle=encode_angle,
-                      action_delay=config.delay,
-                      use_tire_model=True,
-                      use_obs_noise=True,
-                      ctrl_cost_weight=config.ctrl_cost_weight,
-                      margin_factor=config.margin_factor,
-                      )
+    _, user_cfg = _load_remote_config(machine=machine)
+
+    """Setup car reward kwargs"""
+    car_reward_kwargs = dict(encode_angle=encode_angle,
+                             ctrl_cost_weight=config.ctrl_cost_weight,
+                             margin_factor=config.margin_factor)
+    """Set up env"""
+    if bool(config.sim):
+        env = RCCarSimEnv(encode_angle=encode_angle,
+                          action_delay=config.delay,
+                          use_tire_model=True,
+                          use_obs_noise=True,
+                          ctrl_cost_weight=config.ctrl_cost_weight,
+                          margin_factor=config.margin_factor,
+                          )
+    else:
+        from sim_transfer.hardware.car_env import CarEnv
+        # We do not perform frame stacking in the env and do it manually here in the rollout function.
+        env = CarEnv(
+            encode_angle=encode_angle,
+            car_id=2,
+            control_time_ms=config.control_time_ms,
+            max_throttle=0.4,
+            car_reward_kwargs=car_reward_kwargs,
+            num_frame_stacks=0
+
+        )

     # initialize train_data as empty arrays
     train_data = {
@@ -176,16 +203,11 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
     ################################################################################
     """Setup key"""
     key = jr.PRNGKey(config.seed)
-    key_bnn, key_run_episodes, key_dummy_sac_trainer = jr.split(key, 3)
-
-    """Setup car reward kwargs"""
-    car_reward_kwargs = dict(encode_angle=encode_angle,
-                             ctrl_cost_weight=config.ctrl_cost_weight,
-                             margin_factor=config.margin_factor)
+    key_bnn, key_run_episodes, key_dummy_sac_trainer, key = jr.split(key, 4)

     """Setup SAC config dict"""
     num_env_steps_between_updates = 16
-    num_envs = 64
+    num_envs = config.num_sac_envs
     sac_kwargs = dict(num_timesteps=config.sac_num_env_steps,
                       num_evals=20,
                       reward_scaling=10,
@@ -217,11 +239,13 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
     total_config = sac_kwargs | config._asdict() | car_reward_kwargs

     """ WANDB & Logging configuration """
-    wandb_config = {'project': config.project_name, 'entity': WANDB_ENTITY, 'resume': 'allow',
-                    'dir': WANDB_LOG_DIR_EULER if os.path.isdir(WANDB_LOG_DIR_EULER) else '/tmp/',
+    wandb_config = {'project': config.project_name, 'entity': user_cfg['wandb_entity'], 'resume': 'allow',
+                    'dir': user_cfg['wandb_log_dir_euler'] if os.path.isdir(user_cfg['wandb_log_dir_euler']) \
+                        else '/tmp/',
                     'config': total_config, 'settings': {'_service_wait': 300}}
     wandb.init(**wandb_config)
-    wandb_config['id'] = wandb.run.id
+    run_id = wandb.run.id
+    wandb_config['id'] = run_id

     remote_training = not (machine == 'local')
     dump_dir = os.path.join(RESULT_DIR, 'online_rl_hardware', wandb_config['id'])
@@ -229,12 +253,10 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
     log_path = os.path.join(dump_dir, f"{wandb_config['id']}.log")

     if machine == 'euler':
-        wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + EULER_ENTITY}
+        wandb_config_remote = wandb_config | {'dir': '/cluster/scratch/' + user_cfg['euler_entity']}
     else:
         wandb_config_remote = wandb_config | {'dir': '/tmp/'}

-    if remote_training:
-        wandb.finish()
     sys.stdout = Logger(log_path, stream=sys.stdout)
     sys.stderr = Logger(log_path, stream=sys.stderr)
     print(f'\nDumping trajectories and logs to {dump_dir}\n')
@@ -257,29 +279,55 @@ def main(config: MainConfig = MainConfig(), encode_angle: bool = True,
                                   bnn_training_test_ratio=0.2,
                                   max_num_episodes=100)

+    initial_states_fraction = max(min(config.initial_state_fraction, 0.9999), 0.0)
+
+    def init_state_points(true_buffer_points):
+        desired_points = int(initial_states_fraction * true_buffer_points
+                             / (1 - initial_states_fraction))
+        return min(desired_points, config.max_replay_size_true_data_buffer)
+
     """ Set up dummy SAC trainer for getting the policy from policy params """
     dummy_sac_trainer = set_up_dummy_sac_trainer(main_config=config, mbrl_config=mbrl_config, key=key)

+    key, key_eval_buffer = jr.split(key)
+    if config.eval_only_on_init_states:
+        eval_buffer_transitions = prepare_init_transitions_for_car_env(key=key_eval_buffer,
+                                                                       number_of_samples=1000,
+                                                                       num_frame_stack=config.num_stacked_actions)
+    else:
+        eval_buffer_transitions = None
+
     """ Main loop over episodes """
     for episode_id in range(1, config.num_episodes + 1):
-        if remote_training:
-            wandb.init(**wandb_config)
         sys.stdout = Logger(log_path, stream=sys.stdout)
         sys.stderr = Logger(log_path, stream=sys.stderr)
         print('\n\n------- Episode', episode_id)
         key, key_episode = jr.split(key)
-
+        key_episode, key_init_buffer = jr.split(key_episode)
+
+        num_points = train_data['x_train'].shape[0]
+        num_init_state_points = init_state_points(num_points)
+        if num_init_state_points > 0:
+            init_transitions = prepare_init_transitions_for_car_env(key=key_init_buffer,
+                                                                    number_of_samples=num_init_state_points,
+                                                                    num_frame_stack=config.num_stacked_actions)
+        else:
+            init_transitions = None
         # train model & policy
+        if remote_training:
+            wandb_config_remote['id'] = f'{run_id}_{episode_id}'
         policy_params, bnn = train_model_based_policy_remote(
             train_data=train_data, bnn_model=bnn, config=mbrl_config, key=key_episode, episode_idx=episode_id,
             machine=machine, wandb_config=wandb_config_remote,
-            remote_training=remote_training)
+            remote_training=remote_training, reset_buffer_transitions=init_transitions,
+            eval_buffer_transitions=eval_buffer_transitions)

         # get callable policy from policy params
-        def policy(x):
-            return dummy_sac_trainer.make_policy(policy_params, deterministic=True)(x, jr.PRNGKey(0))[0]
+        def policy(x, key: jr.PRNGKey = jr.PRNGKey(0)):
+            return dummy_sac_trainer.make_policy(policy_params,
+                                                 deterministic=bool(config.deterministic_policy))(x, key)[0]

         # perform policy rollout on the car
         stacked_actions = jnp.zeros(shape=(config.num_stacked_actions * mbrl_config.u_dim,))
@@ -288,7 +336,7 @@ def policy(x):
         actions, rewards, pure_obs = [], [], []
         for i in range(config.num_env_steps):
             rng_key_rollouts, rng_key_act = jr.split(rng_key_rollouts)
-            act = policy(obs)
+            act = policy(obs, rng_key_act)
             obs, reward, _, _ = env.step(act)
             rewards.append(reward)
             actions.append(act)
@@ -300,7 +348,7 @@ def policy(x):

         # logging and saving
         trajectory, actions, rewards, pure_obs = map(lambda arr: jnp.array(arr),
-                                                      [trajectory, actions, rewards, pure_obs])
+                                                     [trajectory, actions, rewards, pure_obs])
         traj_summary = {'episode_id': episode_id, 'trajectory': trajectory, 'actions': actions, 'rewards': rewards,
                         'obs': pure_obs, 'return': jnp.sum(rewards)}
@@ -318,9 +366,6 @@ def policy(x):
         train_data['y_train'] = jnp.concatenate([train_data['y_train'], trajectory[1:, :mbrl_config.x_dim]], axis=0)
         print(f'Size of train_data in episode {episode_id}:', train_data['x_train'].shape[0])

-        if remote_training:
-            wandb.finish()
-

 if __name__ == '__main__':
     import argparse
@@ -328,12 +373,20 @@ def policy(x):
     parser = argparse.ArgumentParser(description='Meta-BO run')
     parser.add_argument('--seed', type=int, default=914)
     parser.add_argument('--project_name', type=str, default='OnlineRL_RCCar')
-    parser.add_argument('--machine', type=str, default='optimality')
+    parser.add_argument('--machine', type=str, default='euler')
     parser.add_argument('--gpu', type=int, default=1)
+    parser.add_argument('--sim', type=int, default=1)
+    parser.add_argument('--control_time_ms', type=float, default=24.)
     parser.add_argument('--prior', type=str, default='none_FVSGD')
-    parser.add_argument('--num_env_steps', type=int, default=200, help='number of steps in the environment per episode')
+    parser.add_argument('--num_env_steps', type=int, default=200)
+    parser.add_argument('--bnn_train_steps', type=int, default=40_000)
+    parser.add_argument('--sac_num_env_steps', type=int, default=1_000_000)
+    parser.add_argument('--num_sac_envs', type=int, default=64)
     parser.add_argument('--reset_bnn', type=int, default=0)
+    parser.add_argument('--deterministic_policy', type=int, default=1)
+    parser.add_argument('--num_f_samples', type=int, default=512)
+    parser.add_argument('--initial_state_fraction', type=float, default=0.5)
     args = parser.parse_args()

     if not args.gpu:
@@ -345,5 +398,14 @@ def policy(x):
                     seed=args.seed,
                     project_name=args.project_name,
                     num_env_steps=args.num_env_steps,
-                    reset_bnn=args.reset_bnn),
+                    reset_bnn=args.reset_bnn,
+                    sim=args.sim,
+                    control_time_ms=args.control_time_ms,
+                    deterministic_policy=args.deterministic_policy,
+                    initial_state_fraction=args.initial_state_fraction,
+                    bnn_train_steps=args.bnn_train_steps,
+                    sac_num_env_steps=args.sac_num_env_steps,
+                    num_sac_envs=args.num_sac_envs,
+                    num_f_samples=args.num_f_samples,
+                    ),
          machine=args.machine)
diff --git a/experiments/online_rl_hardware/remote_config.json b/experiments/online_rl_hardware/remote_config.json
index e65ad63d..2c5c45cc 100644
--- a/experiments/online_rl_hardware/remote_config.json
+++ b/experiments/online_rl_hardware/remote_config.json
@@ -1,22 +1,26 @@
 {
-  "euler": {
-    "local_dir": "/tmp/ssh_remote/",
-    "local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
-    "remote_machine": "rojonas@euler",
-    "remote_dir": "/cluster/scratch/rojonas/",
-    "remote_pre_cmd": "srun --gpus=1 --cpus-per-task=4 --time=01:00:00",
-    "remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python",
-    "remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py",
-    "remote_pythonpath": "/cluster/project/infk/krause/rojonas/sim_transfer"
-  },
-  "optimality": {
-    "local_dir": "/tmp/ssh_remote/",
-    "local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
-    "remote_machine": "rojonas@optimality.inf.ethz.ch",
-    "remote_dir": "/tmp/ssh_remote/",
-    "remote_pre_cmd": "",
-    "remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python",
-    "remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py",
-    "remote_pythonpath": "/local/rojonas/sim_transfer"
+  "remote_machines": {
+    "euler": {
+      "local_dir": "/tmp/ssh_remote/",
+      "local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py",
+      "remote_machine": "rojonas@euler",
+      "remote_dir": "/cluster/scratch/rojonas/",
+ "remote_pre_cmd": "srun --gpus=1 --gres=gpumem:10240m --cpus-per-task=4 --mem-per-cpu=8192m --time=01:00:00", + "remote_interpreter": "/cluster/project/infk/krause/rojonas/.venv/sim_transfer_gpu/bin/python", + "remote_script": "/cluster/project/infk/krause/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py" + }, + "optimality": { + "local_dir": "/tmp/ssh_remote/", + "local_script": "/Users/rojonas/Dropbox/Eigene_Dateien/ETH/02_Projects/20_sim_transfer/code_sim_transfer/experiments/online_rl_hardware/train_policy.py", + "remote_machine": "rojonas@optimality.inf.ethz.ch", + "remote_dir": "/tmp/ssh_remote/", + "remote_pre_cmd": "", + "remote_interpreter": "/local/rojonas/miniconda3/envs/sim_transfer_gpu/bin/python", + "remote_script": "/local/rojonas/sim_transfer/experiments/online_rl_hardware/train_policy.py" + } +}, + "user_config": { + "wandb_entity": "jonasrothfuss", + "euler_entity": "rojonas" } } \ No newline at end of file diff --git a/experiments/online_rl_hardware/train_policy.py b/experiments/online_rl_hardware/train_policy.py index 2908bc25..762999fb 100644 --- a/experiments/online_rl_hardware/train_policy.py +++ b/experiments/online_rl_hardware/train_policy.py @@ -2,11 +2,12 @@ import copy import time import chex +import jax import jax.numpy as jnp import jax.random as jr import numpy as np import wandb - +from brax.training.types import Transition from sim_transfer.models.abstract_model import BatchedNeuralNetworkModel from typing import Dict @@ -14,13 +15,17 @@ from experiments.online_rl_hardware.utils import (load_data, dump_model, ModelBasedRLConfig, init_transition_buffer, add_data_to_buffer, set_up_model_based_sac_trainer) + def train_model_based_policy(train_data: Dict, bnn_model: BatchedNeuralNetworkModel, key: chex.PRNGKey, episode_idx: int, config: ModelBasedRLConfig, wandb_config: Dict, - remote_training: bool = False): + remote_training: bool = False, + reset_buffer_transitions: Transition | None = None, + eval_buffer_transitions: Transition | None = None, + ): """ train_data = {'x_train': jnp.empty((0, state_dim + (1 + num_framestacks) * action_dim)), 'y_train': jnp.empty((0, state_dim))} @@ -36,8 +41,8 @@ def train_model_based_policy(train_data: Dict, """ Setup the data buffers """ key, key_buffer_init = jr.split(key, 2) - true_data_buffer, true_data_buffer_state = init_transition_buffer(config=config, key=key_buffer_init) - true_data_buffer, true_data_buffer_state = add_data_to_buffer(true_data_buffer, true_data_buffer_state, + true_data_buffer, init_buffer_state = init_transition_buffer(config=config, key=key_buffer_init) + true_data_buffer, true_data_buffer_state = add_data_to_buffer(true_data_buffer, init_buffer_state, x_data=x_all, y_data=y_all, config=config) """Train transition model""" @@ -54,12 +59,22 @@ def train_model_based_policy(train_data: Dict, # Train model if config.reset_bnn: bnn_model.reinit(rng_key=key_reinit_model) - bnn_model.fit(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True, - keep_the_best=config.return_best_bnn, metrics_objective='eval_nll', log_period=2000) + bnn_model.fit_with_scan(x_train=x_train, y_train=y_train, x_eval=x_test, y_eval=y_test, log_to_wandb=True, + keep_the_best=config.return_best_bnn, metrics_objective='eval_nll', log_period=2000) print(f'Time fo training the transition model: {time.time() - t:.2f} seconds') """Train policy""" t = time.time() + if reset_buffer_transitions: + sac_buffer_state = true_data_buffer.insert(true_data_buffer_state, reset_buffer_transitions) + 
+    else:
+        sac_buffer_state = true_data_buffer_state
+
+    if eval_buffer_transitions:
+        eval_buffer_state = true_data_buffer.insert(init_buffer_state, eval_buffer_transitions)
+    else:
+        eval_buffer_state = None
+
     _sac_kwargs = config.sac_kwargs
     # TODO: Be careful!!
     if num_training_points == 0:
@@ -69,8 +84,8 @@
     key, key_sac_training, key_sac_trainer_init = jr.split(key, 3)

     sac_trainer = set_up_model_based_sac_trainer(
-        bnn_model=bnn_model, data_buffer=true_data_buffer, data_buffer_state=true_data_buffer_state,
-        key=key_sac_trainer_init, config=config, sac_kwargs=_sac_kwargs)
+        bnn_model=bnn_model, data_buffer=true_data_buffer, data_buffer_state=sac_buffer_state,
+        key=key_sac_trainer_init, config=config, sac_kwargs=_sac_kwargs, eval_buffer_state=eval_buffer_state)

     policy_params, metrics = sac_trainer.run_training(key=key_sac_training)
diff --git a/experiments/online_rl_hardware/utils.py b/experiments/online_rl_hardware/utils.py
index 3275b8e6..6f4fdf47 100644
--- a/experiments/online_rl_hardware/utils.py
+++ b/experiments/online_rl_hardware/utils.py
@@ -1,3 +1,4 @@
+import random
 from typing import Any, NamedTuple, Dict

 from brax.training.replay_buffers import UniformSamplingQueue
@@ -5,10 +6,10 @@
 from brax.training.replay_buffers import ReplayBuffer, ReplayBufferState

 from sim_transfer.rl.model_based_rl.learned_system import LearnedCarSystem
-from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_FSVGD, BNN_SVGD
+from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_FSVGD, BNN_SVGD, BNNGreyBox
 from sim_transfer.sims.simulators import AdditiveSim, PredictStateChangeWrapper, GaussianProcessSim
 from sim_transfer.sims.simulators import RaceCarSim, StackedActionSimWrapper
-
+from sim_transfer.sims.envs import RCCarSimEnv
 from mbpo.optimizers.policy_optimizers.sac.sac import SAC
 from mbpo.systems.brax_wrapper import BraxWrapper
@@ -42,6 +43,7 @@ def execute(cmd: str, verbosity: int = 0) -> None:
         print(cmd)
     os.system(cmd)

+
 def load_data(data_load_path: str) -> Any:
     # loads the pkl file
     with open(data_load_path, 'rb') as f:
@@ -99,7 +101,8 @@ def add_data_to_buffer(buffer: ReplayBuffer, buffer_state: ReplayBufferState, x_


 def set_up_model_based_sac_trainer(bnn_model, data_buffer, data_buffer_state, key: jax.random.PRNGKey,
-                                   config: ModelBasedRLConfig, sac_kwargs: dict = None):
+                                   config: ModelBasedRLConfig, sac_kwargs: dict = None,
+                                   eval_buffer_state: ReplayBufferState | None = None):
     if sac_kwargs is None:
         sac_kwargs = config.sac_kwargs
@@ -109,14 +112,23 @@ def set_up_model_based_sac_trainer(bnn_model, data_buffer, data_buffer_state, ke
                               num_frame_stack=config.num_stacked_actions,
                               **config.car_reward_kwargs)

+    if eval_buffer_state is None:
+        eval_buffer_state = data_buffer_state
+
+    key, eval_env_key = jax.random.split(key)
     env = BraxWrapper(system=system,
                       sample_buffer_state=data_buffer_state,
                       sample_buffer=data_buffer,
                       system_params=system.init_params(key))

+    eval_env = BraxWrapper(system=system,
+                           sample_buffer_state=eval_buffer_state,
+                           sample_buffer=data_buffer,
+                           system_params=system.init_params(eval_env_key))
+
     # Here we create eval envs
     sac_trainer = SAC(environment=env,
-                      eval_environment=env,
+                      eval_environment=eval_env,
                       eval_key_fixed=True,
                       return_best_model=config.return_best_policy,
                       **sac_kwargs,
                       )
@@ -124,7 +136,8 @@ def set_up_model_based_sac_trainer(bnn_model, data_buffer, data_buffer_state, ke


 def set_up_bnn_dynamics_model(config: Any, key: jax.random.PRNGKey):
-    sim = RaceCarSim(encode_angle=True, use_blend=config.sim_prior == 'high_fidelity', car_id=2)
+    use_blend = 'high_fidelity' in config.sim_prior
+    sim = RaceCarSim(encode_angle=True, use_blend=use_blend, car_id=2)
     if config.num_stacked_actions > 0:
         sim = StackedActionSimWrapper(sim, num_stacked_actions=config.num_stacked_actions, action_size=2)
     if config.predict_difference:
@@ -157,6 +170,28 @@ def set_up_bnn_dynamics_model(config: Any, key: jax.random.PRNGKey):
             **standard_params,
             bandwidth_svgd=1.0,
         )
+    elif config.sim_prior == 'high_fidelity_grey_box' or config.sim_prior == 'low_fidelity_grey_box':
+        base_bnn = BNN_FSVGD(
+            **standard_params,
+            domain=sim.domain,
+            bandwidth_svgd=config.bandwidth_svgd,
+        )
+        bnn = BNNGreyBox(
+            base_bnn=base_bnn,
+            sim=sim,
+            use_base_bnn=True,
+        )
+    elif config.sim_prior == 'high_fidelity_sim' or config.sim_prior == 'low_fidelity_sim':
+        base_bnn = BNN_FSVGD(
+            **standard_params,
+            domain=sim.domain,
+            bandwidth_svgd=config.bandwidth_svgd,
+        )
+        bnn = BNNGreyBox(
+            base_bnn=base_bnn,
+            sim=sim,
+            use_base_bnn=False,
+        )
     elif config.sim_prior == 'high_fidelity_no_aditive_GP':
         bnn = BNN_FSVGD_SimPrior(
             **standard_params,
@@ -207,4 +242,26 @@ def set_up_dummy_sac_trainer(main_config, mbrl_config: ModelBasedRLConfig, key:
         bnn_model=bnn, data_buffer=true_data_buffer, data_buffer_state=true_data_buffer_state,
         key=key_bnn, config=mbrl_config)

-    return sac_trainer
\ No newline at end of file
+    return sac_trainer
+
+
+def prepare_init_transitions_for_car_env(key: jax.random.PRNGKey, number_of_samples: int, num_frame_stack: int = 3):
+    sim = RCCarSimEnv(encode_angle=True, use_tire_model=True)
+    action_dim = 2
+    key_init_state = jax.random.split(key, number_of_samples)
+    state_obs = jax.vmap(sim.reset)(rng_key=key_init_state)
+    framestacked_actions = jnp.zeros(
+        shape=(number_of_samples, num_frame_stack * action_dim))
+    actions = jnp.zeros(shape=(number_of_samples, action_dim))
+    rewards = jnp.zeros(shape=(number_of_samples,))
+    discounts = 0.99 * jnp.ones(shape=(number_of_samples,))
+    transitions = Transition(observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1),
+                             action=actions,
+                             reward=rewards,
+                             discount=discounts,
+                             next_observation=jnp.concatenate([state_obs, framestacked_actions], axis=-1))
+    return transitions
+
+
+def get_random_hash() -> str:
+    return "%032x" % random.getrandbits(128)
diff --git a/experiments/util.py b/experiments/util.py
index be2be7ec..1ddf645a 100644
--- a/experiments/util.py
+++ b/experiments/util.py
@@ -149,7 +149,7 @@ def generate_run_commands(command_list: List[str], output_file_list: Optional[Li
                    f'--cpus-per-task {num_cpus} '

     if num_gpus > 0:
-        bsub_cmd += f'-G {num_gpus} --gres=gpumem:10240m'
+        bsub_cmd += f'-G {num_gpus} --gres=gpumem:10240m '

     assert output_file_list is None or len(command_list) == len(output_file_list)
     if output_file_list is None:
diff --git a/sim_transfer/sims/car_sim_config.py b/sim_transfer/sims/car_sim_config.py
index e2bd9a3e..3ab6b34f 100644
--- a/sim_transfer/sims/car_sim_config.py
+++ b/sim_transfer/sims/car_sim_config.py
@@ -96,20 +96,20 @@
     'm': 1.65,
     'l_f': 0.13,
     'l_r': 0.17,
-    'angle_offset': 0.031,
+    'angle_offset': 0.00,
     'b_f': 2.58,
-    'b_r': 4.75,
+    'b_r': 5.0,
     'blend_ratio_lb': 0.01,
     'blend_ratio_ub': 0.01,
     'c_d': 0.0,
     'c_f': 1.2,
-    'c_m_1': 8.46,
-    'c_m_2': 1.6,
+    'c_m_1': 8.0,
+    'c_m_2': 1.5,
     'c_r': 1.27,
     'd_f': 0.02,
     'd_r': 0.017,
     'i_com': 0.01,
-    'steering_limit': 0.347
+    'steering_limit': 0.3
 }

 BOUNDS_PARAMS_BICYCLE_CAR2: Dict = {
@@ -117,20 +117,20 @@
     'm': (1.6, 1.7),
     'l_f': (0.11, 0.15),
     'l_r': (0.15, 0.19),
-    'angle_offset': (0.001, 0.05),
-    'b_f': (2.2, 2.8),
-    'b_r': (3.0, 7.0),
+    'angle_offset': (-0.15, 0.15),
+    'b_f': (2.4, 2.6),
+    'b_r': (2.0, 8.0),
     'blend_ratio_lb': (0.4, 0.4),
     'blend_ratio_ub': (0.5, 0.5),
     'c_d': (0.01, 0.01),
     'c_f': (1.2, 1.2),
-    'c_m_1': (6., 11.),
-    'c_m_2': (1.1, 1.8),
+    'c_m_1': (6., 10.),
+    'c_m_2': (1.0, 1.8),
     'c_r': (1.27, 1.27),
     'd_f': (0.02, 0.02),
     'd_r': (0.017, 0.017),
     'i_com': (0.01, 0.1),
-    'steering_limit': (0.25, 0.45),
+    'steering_limit': (0.15, 0.4),
 }

 DEFAULT_PARAMS_BLEND_CAR2: Dict = {
@@ -138,20 +138,20 @@
     'm': 1.65,
     'l_f': 0.13,
     'l_r': 0.17,
-    'angle_offset': 0.02611047,
-    'b_f': 2.5943623,
-    'b_r': 5.2826314,
-    'blend_ratio_lb': 0.0005,
-    'blend_ratio_ub': 0.012,
+    'angle_offset': 0.0,
+    'b_f': 2.75,
+    'b_r': 5.0,
+    'blend_ratio_lb': 0.001,
+    'blend_ratio_ub': 0.017,
     'c_d': 0.0,
-    'c_f': 1.294,
-    'c_m_1': 8.9,
-    'c_m_2': 1.38,
-    'c_r': 0.911,
-    'd_f': 0.43,
-    'd_r': 0.28,
-    'i_com': 0.048,
-    'steering_limit': 0.7,
+    'c_f': 1.45,
+    'c_m_1': 8.2,
+    'c_m_2': 1.25,
+    'c_r': 1.3,
+    'd_f': 0.4,
+    'd_r': 0.3,
+    'i_com': 0.06,
+    'steering_limit': 0.6,
 }

 BOUNDS_PARAMS_BLEND_CAR2 = {
@@ -159,18 +159,18 @@
     'm': (1.6, 1.7),
     'l_f': (0.125, 0.135),
     'l_r': (0.165, 0.175),
-    'angle_offset': (0.0, 0.035),
-    'b_f': (2.5, 3.5),
-    'b_r': (4.0, 10.0),
+    'angle_offset': (-0.15, 0.15),
+    'b_f': (2.0, 4.0),
+    'b_r': (3.0, 10.0),
     'blend_ratio_lb': (0.0001, 0.1),
-    'blend_ratio_ub': (0.0001, 0.1),
+    'blend_ratio_ub': (0.0001, 0.2),
     'c_d': (0.0, 0.0),
-    'c_f': (1.1, 1.5),
-    'c_m_1': (7., 10.),
-    'c_m_2': (1.1, 1.5),
-    'c_r': (0.4, 1.3),
-    'd_f': (0.3, 0.6),
+    'c_f': (1.1, 2.0),
+    'c_m_1': (6.5, 10.0),
+    'c_m_2': (1.0, 1.5),
+    'c_r': (0.4, 2.0),
+    'd_f': (0.25, 0.6),
     'd_r': (0.15, 0.45),
-    'i_com': (0.03, 0.07),
-    'steering_limit': (0.6, 0.8),
+    'i_com': (0.03, 0.18),
+    'steering_limit': (0.4, 0.75),
 }
\ No newline at end of file
diff --git a/sim_transfer/sims/visualize_racecar.py b/sim_transfer/sims/visualize_racecar.py
index 49effdf3..3474f187 100644
--- a/sim_transfer/sims/visualize_racecar.py
+++ b/sim_transfer/sims/visualize_racecar.py
@@ -2,8 +2,10 @@
 import jax.numpy as jnp
 import jax
 from matplotlib import pyplot as plt
+from sim_transfer.sims.envs import RCCarSimEnv
+from sim_transfer.sims.util import decode_angles, encode_angles

-sim_lf = RaceCarSim(use_blend=False)
+sim_lf = RaceCarSim(use_blend=False, car_id=2)
 sim_hf = RaceCarSim(use_blend=True, car_id=2)
@@ -12,6 +14,9 @@
                         [-1, -0.5, 0., 1.0, 0.5, 0.5, 1.],
                         [0.5, -1.5, 0., -2.0, -0.5, -0.5, 1.],])

+ACTIONS = [lambda t: jnp.array([- 1 * jnp.sin(2 * t), 0.8 / (t + 1)]),
+           lambda t: jnp.array([+ 1 * jnp.sin(4 * t), 0.8 / (t + 1)]),
+           lambda t: jnp.array([- 1, 0.8 / (t + 1)])]

 fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(15, 15))
 for k in range(3):
@@ -26,12 +31,7 @@
         actions = []
         for i in range(30):
             t = i / 30.
-            if k == 0:
-                a = jnp.array([- 1 * jnp.sin(2 * t), 0.8 / (t + 1)])
-            elif k == 1:
-                a = jnp.array([+ 1 * jnp.sin(4 * t), 0.8 / (t + 1)])
-            elif k == 2:
-                a = jnp.array([- 1, 0.8 / (t + 1)])
+            a = ACTIONS[k](t)
             a = jnp.repeat(a[None, :], NUM_PARALLEL, axis=0)
             x = jnp.concatenate([s, a], axis=-1)
             s = fun_stacked(x)
@@ -40,12 +40,23 @@
         traj = jnp.stack(traj, axis=0)
         actions = jnp.stack(actions, axis=0)

-        from matplotlib import pyplot as plt
         for i in range(NUM_PARALLEL):
             axes[k][j].plot(traj[:, i, 0], traj[:, i, 1])
         axes[k][j].set_xlim(-1, 1.)
         axes[k][j].set_ylim(-2, 1.)
+        env = RCCarSimEnv(encode_angle=True, use_obs_noise=False, use_tire_model=bool(j))
+        obs = env.reset()
+        env._state = decode_angles(INIT_STATE[k], angle_idx=2)
+        traj_env = [encode_angles(env._state, angle_idx=2)]
+        for i in range(30):
+            t = i / 30.
+            a = ACTIONS[k](t)
+            obs, _, _, _ = env.step(a)
+            traj_env.append(obs)
+        traj_env = jnp.stack(traj_env, axis=0)
+        axes[k][j].plot(traj_env[:, 0], traj_env[:, 1], color='black', linewidth=3)
+
     axes[k][2].plot(jnp.arange(len(actions[:, 0, 0])), actions[:, 0, 0])
 fig.show()
\ No newline at end of file