diff --git a/experiments/hardware_experiments/api_experiments/evaluate_offline_rl_on_hardware.py b/experiments/hardware_experiments/api_experiments/evaluate_offline_rl_on_hardware.py
index 985999a0..426fce9c 100644
--- a/experiments/hardware_experiments/api_experiments/evaluate_offline_rl_on_hardware.py
+++ b/experiments/hardware_experiments/api_experiments/evaluate_offline_rl_on_hardware.py
@@ -19,7 +19,7 @@ class RunSpec(NamedTuple):
     group_name: str
     run_id: str
-    reward_config: dict| None = None
+    reward_config: dict | None = None


 def run_all_hardware_experiments(project_name_load: str,
@@ -27,6 +27,7 @@ def run_all_hardware_experiments(project_name_load: str,
                                  desired_config: dict | None = None,
                                  control_time_ms: float = 32,
                                  download_data: bool = True,
+                                 use_grey_box: bool = False
                                  ):
     api = wandb.Api()
     project_name = ENTITY + '/' + project_name_load
@@ -45,9 +46,14 @@ def run_all_hardware_experiments(project_name_load: str,

         # Download all models
         runs = api.runs(project_name)
+        i = 0
         for run in runs:
             config = {k: v for k, v in run.config.items() if not k.startswith('_')}
             correct_config = 1
+            group_name = run.group
+            if not use_grey_box:
+                if 'use_grey_box=1' in group_name:
+                    continue
             if desired_config:
                 for key in desired_config.keys():
                     if config[key] != desired_config[key]:
@@ -55,10 +61,13 @@
                         break
             if not correct_config:
                 continue
+            i += 1
+            print(i)
+            print('group_name: ', group_name)
             for file in run.files():
                 if file.name.startswith(dir_to_save):
                     file.download(replace=True, root=os.path.join(local_dir, run.group, run.id))
-
+            print('data_from_sim: ', config['data_from_simulation'])
             keys = ['encode_angle', 'ctrl_cost_weight', 'margin_factor', 'ctrl_diff_weight']
             reward_config = {}
             for key in keys:
@@ -79,18 +88,19 @@
         policy_name = 'models/parameters.pkl'
         bnn_name = 'models/bnn_model.pkl'

-        with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
-            bnn_model = pickle.load(handle)
+        import cloudpickle
+        # with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
+        #     bnn_model = cloudpickle.load(handle)

         with open(os.path.join(pre_path, policy_name), 'rb') as handle:
-            policy_params = pickle.load(handle)
+            policy_params = cloudpickle.load(handle)

         run_with_learned_policy(policy_params=policy_params,
-                                bnn_model=bnn_model,
+                                bnn_model=None,
                                 project_name=project_name_save,
                                 group_name=run_spec.group_name,
                                 run_id=run_spec.run_id,
-                                reward_config=reward_config,
+                                reward_config=run_spec.reward_config,
                                 control_time_ms=control_time_ms)


@@ -124,7 +134,7 @@ def run_with_learned_policy(policy_params,
         project=project_name,
         group=group_name,
         entity=ENTITY,
-        id=run_id,
+        id=run_id + 'f',
         resume="allow",
     )
     policy = rl_from_offline_data.prepare_policy(params=policy_params)
@@ -141,8 +151,6 @@ def run_with_learned_policy(policy_params,
     t_prev = time.time()
     actions = []
-
-

     stacked_actions = jnp.zeros(shape=(num_frame_stack * action_dim,))

     time_diffs = []
@@ -168,10 +176,11 @@ def run_with_learned_policy(policy_params,
         t = time.time()
         time_diff = t - t_prev
         t_prev = t
-        print(i, action, reward, time_diff)
+        # print(i, action, reward, time_diff)
         time_diffs.append(time_diff)
         stacked_actions = obs[state_dim:]
         observations.append(obs)
+        print(obs)
         rewards.append(reward)
         all_stacked_actions.append(stacked_actions)

@@ -190,7 +199,7 @@ def run_with_learned_policy(policy_params,

     if time_diff_std > 0.001:
         Warning('Variability in time difference is too high')
-    if abs(mean_time_diff - 1/30.) < 0.001:
+    if abs(mean_time_diff - 1 / 30.) < 0.001:
         Warning('Control frequency is not maintained with the time difference')
     plt.plot(time_diffs[1:])
     plt.title('time diffs')
@@ -225,15 +234,15 @@ def run_with_learned_policy(policy_params,
     We test the model error on the predicted trajectory
     """
     print('We test the model error on the predicted trajectory')
-    all_outputs = bnn_model.predict_dist(all_inputs, include_noise=False)
+    # all_outputs = bnn_model.predict_dist(all_inputs, include_noise=False)
     sim_key, subkey = jr.split(sim_key)
-    delta_x = all_outputs.sample(seed=subkey)
+    # delta_x = all_outputs.sample(seed=subkey)

-    assert delta_x.shape == target_outputs.shape
+    # assert delta_x.shape == target_outputs.shape
     # Random data for demonstration purposes
-    data = delta_x - target_outputs
-    fig, axes = plot_error_on_the_trajectory(data)
-    wandb.log({'Error of state difference prediction': wandb.Image(fig)})
+    # data = delta_x - target_outputs
+    # fig, axes = plot_error_on_the_trajectory(data)
+    # wandb.log({'Error of state difference prediction': wandb.Image(fig)})

     # We plot the true trajectory
     fig, axes = plot_rc_trajectory(observations[:, :state_dim],
@@ -241,33 +250,33 @@
                                    encode_angle=encode_angle,
                                    show=True)
     wandb.log({'Trajectory_on_true_model': wandb.Image(fig)})
-    sim_obs = sim_obs[:state_dim]
-    for i in range(200):
-        obs = jnp.concatenate([sim_obs, sim_stacked_actions], axis=0)
-        sim_action = policy(obs)
-        # sim_action = np.array(sim_action)
-        z = jnp.concatenate([obs, sim_action], axis=-1)
-        z = z.reshape(1, -1)
-        delta_x_dist = bnn_model.predict_dist(z, include_noise=True)
-        sim_key, subkey = jr.split(sim_key)
-        delta_x = delta_x_dist.sample(seed=subkey)
-        sim_obs = sim_obs + delta_x.reshape(-1)
-
-        # Now we shift the actions
-        old_sim_stacked_actions = sim_stacked_actions
-        sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:])
-        sim_stacked_actions = sim_stacked_actions.at[-action_dim:].set(sim_action)
-        all_sim_actions.append(sim_action)
-        all_sim_obs.append(sim_obs)
-        all_sim_stacked_actions.append(sim_stacked_actions)
-
-    sim_observations_for_plotting = np.stack(all_sim_obs, axis=0)
-    sim_actions_for_plotting = np.stack(all_sim_actions, axis=0)
-    fig, axes = plot_rc_trajectory(sim_observations_for_plotting,
-                                   sim_actions_for_plotting,
-                                   encode_angle=encode_angle,
-                                   show=True)
-    wandb.log({'Trajectory_on_learned_model': wandb.Image(fig)})
+    # sim_obs = sim_obs[:state_dim]
+    # for i in range(200):
+    #     obs = jnp.concatenate([sim_obs, sim_stacked_actions], axis=0)
+    #     sim_action = policy(obs)
+    #     # sim_action = np.array(sim_action)
+    #     z = jnp.concatenate([obs, sim_action], axis=-1)
+    #     z = z.reshape(1, -1)
+    #     delta_x_dist = bnn_model.predict_dist(z, include_noise=True)
+    #     sim_key, subkey = jr.split(sim_key)
+    #     delta_x = delta_x_dist.sample(seed=subkey)
+    #     sim_obs = sim_obs + delta_x.reshape(-1)
+    #
+    #     # Now we shift the actions
+    #     old_sim_stacked_actions = sim_stacked_actions
+    #     sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:])
+    #     sim_stacked_actions = sim_stacked_actions.at[-action_dim:].set(sim_action)
+    #     all_sim_actions.append(sim_action)
+    #     all_sim_obs.append(sim_obs)
+    #     all_sim_stacked_actions.append(sim_stacked_actions)
+    #
+    # sim_observations_for_plotting = np.stack(all_sim_obs, axis=0)
+    # sim_actions_for_plotting = np.stack(all_sim_actions, axis=0)
+    # fig, axes = plot_rc_trajectory(sim_observations_for_plotting,
+    #                                sim_actions_for_plotting,
+    #                                encode_angle=encode_angle,
+    #                                show=True)
+    # wandb.log({'Trajectory_on_learned_model': wandb.Image(fig)})

     wandb.finish()
     return observations, actions
@@ -296,13 +305,80 @@ def plot_error_on_the_trajectory(data):
     return fig, axes


+def evaluate_runs_for_video(num_data_points: int = 50, control_time_ms: float = 28.2):
+    if num_data_points == 50:
+        policy_filenames = [
+            'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
+            '=50_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5' \
+            '/z4rlsbwj/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=50_share_of_x0s='
+            '0.5_sac_only_from_is=0_use_sim_model=0_0.5/j5ekgar9/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=50_'
+            'share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5/3c0m3gpi'
+            '/models/parameters.pkl'
+
+        ]
+    elif num_data_points == 800:
+        policy_filenames = [
+            'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
+            '=800_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5' \
+            '/59ufk8pc/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=800_share_of_x0s='
+            '0.5_sac_only_from_is=0_use_sim_model=0_0.5/sy4devzz/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=800_share_of_x0s'
+            '=0.5_sac_only_from_is=0_use_sim_model=0_0.5/vkybm8ml/models/parameters.pkl'
+        ]
+    elif num_data_points == 2500:
+        policy_filenames = [
+            'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
+            '=2500_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5' \
+            '/tppyk230/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=2500_share_of_x0s='
+            '0.5_sac_only_from_is=0_use_sim_model=0_0.5/xvvy2wde/models/parameters.pkl',
+
+            'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=2500_share_of_x0s'
+            '=0.5_sac_only_from_is=0_use_sim_model=0_0.5/9i1k6d07/models/parameters.pkl'
+        ]
+    else:
+        raise AssertionError
+
+    dummy_reward_kwargs = {
+        'ctrl_cost_weight': 0.005,
+        'margin_factor': 20.0,
+        'ctrl_diff_weight': 0.0,
+        'encode_angle': True,
+    }
+
+    for i, param_file in enumerate(policy_filenames):
+        import cloudpickle
+        # with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
+        #     bnn_model = cloudpickle.load(handle)
+
+        with open(param_file, 'rb') as handle:
+            policy_params = cloudpickle.load(handle)
+
+        run_with_learned_policy(policy_params=policy_params,
+                                bnn_model=None,
+                                project_name='dummy_proj',
+                                group_name='dummy',
+                                run_id=f'dummy{i}',
+                                reward_config=dummy_reward_kwargs,
+                                control_time_ms=control_time_ms)
+
+
 if __name__ == '__main__':
     import pickle

     filename_policy = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
                       '=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/parameters.pkl'
     filename_bnn_model = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
-                         '=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/bnn_model.pkl'
+                         '=2500_share_of_x0s=0.5_train_sac_only_from_inyit_states=0_0.5/tshlnhs0/models/bnn_model.pkl'

     # with open(filename_policy, 'rb') as handle:
     #     policy_params = pickle.load(handle)
@@ -317,10 +393,13 @@ def plot_error_on_the_trajectory(data):
     #                         run_id='Butterfly',
     #                         control_time_ms=32,
     #                         )
-
-    run_all_hardware_experiments(
-        project_name_load='OfflineRLRunsWithGreyBox',
-        project_name_save='OfflineRLRunsWithGreyBox_evaluation',
-        desired_config={'bandwidth_svgd': 0.2, 'data_from_simulation': 0},
-        control_time_ms=32,
-    )
+    evaluate_runs_for_video(num_data_points=2500, control_time_ms=28.5)
+    # run_all_hardware_experiments(
+    #     project_name_load='OfflineRLRunsGreyBoxHW2',
+    #     project_name_save='HWEvaluationGb2',
+    #     desired_config={'bandwidth_svgd': 0.2, 'data_from_simulation': 0, 'num_offline_collected_transitions': 5000,
+    #                     'use_sim_model': 0},
+    #     control_time_ms=28.5,
+    #     download_data=True,
+    #     use_grey_box=True,
+    #)
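A note on the run-selection loop that this patch extends in run_all_hardware_experiments: the grey-box and desired_config checks are applied client-side, after every run in the project has been fetched. If that ever becomes slow, the wandb public API also accepts MongoDB-style filters in api.runs(), so matching runs can be selected server-side. A rough sketch under that assumption; the entity name is a placeholder, and the filter keys mirror the desired_config used in this script:

    import wandb

    api = wandb.Api()
    # Placeholder entity; project name and config keys are the ones that appear in this script.
    runs = api.runs(
        'my_entity/OfflineRLRunsGreyBoxHW2',
        filters={'config.bandwidth_svgd': 0.2, 'config.data_from_simulation': 0},
    )
    for run in runs:
        print(run.group, run.id)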
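On the switch from pickle.load to cloudpickle.load for the policy parameters: cloudpickle writes the standard pickle stream format, so cloudpickle.load can read files written with plain pickle (and vice versa, as long as cloudpickle is installed). The practical difference is on the dumping side, where cloudpickle can serialize closures and locally defined functions that the standard pickle module rejects. A small illustrative sketch; the file path is a throwaway example, not one of the saved_data paths above:

    import cloudpickle

    scale = 2.0
    policy_stub = lambda obs: obs * scale  # plain pickle cannot dump a lambda like this

    with open('/tmp/policy_stub.pkl', 'wb') as handle:
        cloudpickle.dump(policy_stub, handle)

    with open('/tmp/policy_stub.pkl', 'rb') as handle:
        restored = cloudpickle.load(handle)

    print(restored(3.0))  # 6.0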
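One more note on the simulated-rollout block that this patch comments out near the end of run_with_learned_policy: the line sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:]) discards its result, because JAX arrays are immutable and .at[...].set(...) returns a new array rather than updating in place, so the action stack never actually shifts. If that rollout is re-enabled, the update needs to be re-assigned. A minimal sketch of the intended shift; the helper name and toy dimensions are illustrative, not from the patch:

    import jax.numpy as jnp

    def shift_stacked_actions(stacked_actions, new_action, action_dim):
        # .at[...].set(...) is functional: keep the returned array, or the update is lost.
        shifted = stacked_actions.at[:-action_dim].set(stacked_actions[action_dim:])
        shifted = shifted.at[-action_dim:].set(new_action)
        return shifted

    # Toy usage: a frame stack of three 2-dimensional actions, newest at the end.
    stacked = jnp.zeros(3 * 2)
    stacked = shift_stacked_actions(stacked, jnp.array([0.5, -0.5]), action_dim=2)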