working offline rl evaluation file for HW
jonasrothfuss committed Jan 4, 2024
1 parent 8109fe9 commit 781d697
Showing 1 changed file with 132 additions and 53 deletions.
@@ -19,14 +19,15 @@
class RunSpec(NamedTuple):
group_name: str
run_id: str
reward_config: dict| None = None
reward_config: dict | None = None


def run_all_hardware_experiments(project_name_load: str,
project_name_save: str | None = None,
desired_config: dict | None = None,
control_time_ms: float = 32,
download_data: bool = True,
use_grey_box: bool = False
):
api = wandb.Api()
project_name = ENTITY + '/' + project_name_load
@@ -45,20 +46,28 @@ def run_all_hardware_experiments(project_name_load: str,

# Download all models
runs = api.runs(project_name)
i = 0
for run in runs:
config = {k: v for k, v in run.config.items() if not k.startswith('_')}
correct_config = 1
group_name = run.group
if not use_grey_box:
if 'use_grey_box=1' in group_name:
continue
if desired_config:
for key in desired_config.keys():
if config[key] != desired_config[key]:
correct_config = 0
break
if not correct_config:
continue
i += 1
print(i)
print('group_name: ', group_name)
for file in run.files():
if file.name.startswith(dir_to_save):
file.download(replace=True, root=os.path.join(local_dir, run.group, run.id))

print('data_from_sim: ', config['data_from_simulation'])
keys = ['encode_angle', 'ctrl_cost_weight', 'margin_factor', 'ctrl_diff_weight']
reward_config = {}
for key in keys:
@@ -79,18 +88,19 @@ def run_all_hardware_experiments(project_name_load: str,
policy_name = 'models/parameters.pkl'
bnn_name = 'models/bnn_model.pkl'

with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
bnn_model = pickle.load(handle)
import cloudpickle
# with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
# bnn_model = cloudpickle.load(handle)

with open(os.path.join(pre_path, policy_name), 'rb') as handle:
policy_params = pickle.load(handle)
policy_params = cloudpickle.load(handle)

run_with_learned_policy(policy_params=policy_params,
bnn_model=bnn_model,
bnn_model=None,
project_name=project_name_save,
group_name=run_spec.group_name,
run_id=run_spec.run_id,
reward_config=reward_config,
reward_config=run_spec.reward_config,
control_time_ms=control_time_ms)
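
For context, a minimal sketch of the cloudpickle-based parameter loading this revision switches to; the load_policy_params helper name is illustrative and not part of the repository, while the path layout mirrors the pre_path / 'models/parameters.pkl' convention used above.

import os
import cloudpickle

def load_policy_params(pre_path: str, policy_name: str = 'models/parameters.pkl'):
    # cloudpickle can deserialize closures and locally defined callables
    # that the standard pickle module may fail to reconstruct.
    with open(os.path.join(pre_path, policy_name), 'rb') as handle:
        return cloudpickle.load(handle)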


@@ -124,7 +134,7 @@ def run_with_learned_policy(policy_params,
project=project_name,
group=group_name,
entity=ENTITY,
id=run_id,
id=run_id + 'f',
resume="allow",
)
policy = rl_from_offline_data.prepare_policy(params=policy_params)
@@ -141,8 +151,6 @@ def run_with_learned_policy(policy_params,
t_prev = time.time()
actions = []



stacked_actions = jnp.zeros(shape=(num_frame_stack * action_dim,))
time_diffs = []

@@ -168,10 +176,11 @@ def run_with_learned_policy(policy_params,
t = time.time()
time_diff = t - t_prev
t_prev = t
print(i, action, reward, time_diff)
# print(i, action, reward, time_diff)
time_diffs.append(time_diff)
stacked_actions = obs[state_dim:]
observations.append(obs)
print(obs)
rewards.append(reward)
all_stacked_actions.append(stacked_actions)

@@ -190,7 +199,7 @@ def run_with_learned_policy(policy_params,

if time_diff_std > 0.001:
Warning('Variability in time difference is too high')
if abs(mean_time_diff - 1/30.) < 0.001:
if abs(mean_time_diff - 1 / 30.) < 0.001:
Warning('Control frequency is not maintained with the time difference')
plt.plot(time_diffs[1:])
plt.title('time diffs')
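
A short sketch of the timing check that the two conditions above appear to intend (warn on high jitter, or when the mean period drifts from the assumed 1/30 s target); note that a bare Warning(...) call only constructs an exception object, so warnings.warn is used here, and the helper name and thresholds are illustrative.

import warnings

TARGET_PERIOD_S = 1.0 / 30.0  # assumed target control period

def check_control_timing(mean_time_diff: float, time_diff_std: float) -> None:
    # Warn on high loop jitter.
    if time_diff_std > 0.001:
        warnings.warn('Variability in time difference is too high')
    # Warn when the mean period drifts away from the target.
    if abs(mean_time_diff - TARGET_PERIOD_S) > 0.001:
        warnings.warn('Control frequency is not maintained with the time difference')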
@@ -225,49 +234,49 @@ def run_with_learned_policy(policy_params,
We test the model error on the predicted trajectory
"""
print('We test the model error on the predicted trajectory')
all_outputs = bnn_model.predict_dist(all_inputs, include_noise=False)
# all_outputs = bnn_model.predict_dist(all_inputs, include_noise=False)
sim_key, subkey = jr.split(sim_key)
delta_x = all_outputs.sample(seed=subkey)
# delta_x = all_outputs.sample(seed=subkey)

assert delta_x.shape == target_outputs.shape
# assert delta_x.shape == target_outputs.shape
# Random data for demonstration purposes
data = delta_x - target_outputs
fig, axes = plot_error_on_the_trajectory(data)
wandb.log({'Error of state difference prediction': wandb.Image(fig)})
# data = delta_x - target_outputs
# fig, axes = plot_error_on_the_trajectory(data)
# wandb.log({'Error of state difference prediction': wandb.Image(fig)})

# We plot the true trajectory
fig, axes = plot_rc_trajectory(observations[:, :state_dim],
actions,
encode_angle=encode_angle,
show=True)
wandb.log({'Trajectory_on_true_model': wandb.Image(fig)})
sim_obs = sim_obs[:state_dim]
for i in range(200):
obs = jnp.concatenate([sim_obs, sim_stacked_actions], axis=0)
sim_action = policy(obs)
# sim_action = np.array(sim_action)
z = jnp.concatenate([obs, sim_action], axis=-1)
z = z.reshape(1, -1)
delta_x_dist = bnn_model.predict_dist(z, include_noise=True)
sim_key, subkey = jr.split(sim_key)
delta_x = delta_x_dist.sample(seed=subkey)
sim_obs = sim_obs + delta_x.reshape(-1)

# Now we shift the actions
old_sim_stacked_actions = sim_stacked_actions
sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:])
sim_stacked_actions = sim_stacked_actions.at[-action_dim:].set(sim_action)
all_sim_actions.append(sim_action)
all_sim_obs.append(sim_obs)
all_sim_stacked_actions.append(sim_stacked_actions)

sim_observations_for_plotting = np.stack(all_sim_obs, axis=0)
sim_actions_for_plotting = np.stack(all_sim_actions, axis=0)
fig, axes = plot_rc_trajectory(sim_observations_for_plotting,
sim_actions_for_plotting,
encode_angle=encode_angle,
show=True)
wandb.log({'Trajectory_on_learned_model': wandb.Image(fig)})
# sim_obs = sim_obs[:state_dim]
# for i in range(200):
# obs = jnp.concatenate([sim_obs, sim_stacked_actions], axis=0)
# sim_action = policy(obs)
# # sim_action = np.array(sim_action)
# z = jnp.concatenate([obs, sim_action], axis=-1)
# z = z.reshape(1, -1)
# delta_x_dist = bnn_model.predict_dist(z, include_noise=True)
# sim_key, subkey = jr.split(sim_key)
# delta_x = delta_x_dist.sample(seed=subkey)
# sim_obs = sim_obs + delta_x.reshape(-1)
#
# # Now we shift the actions
# old_sim_stacked_actions = sim_stacked_actions
# sim_stacked_actions.at[:-action_dim].set(old_sim_stacked_actions[action_dim:])
# sim_stacked_actions = sim_stacked_actions.at[-action_dim:].set(sim_action)
# all_sim_actions.append(sim_action)
# all_sim_obs.append(sim_obs)
# all_sim_stacked_actions.append(sim_stacked_actions)
#
# sim_observations_for_plotting = np.stack(all_sim_obs, axis=0)
# sim_actions_for_plotting = np.stack(all_sim_actions, axis=0)
# fig, axes = plot_rc_trajectory(sim_observations_for_plotting,
# sim_actions_for_plotting,
# encode_angle=encode_angle,
# show=True)
# wandb.log({'Trajectory_on_learned_model': wandb.Image(fig)})
wandb.finish()
return observations, actions
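
For reference, a minimal sketch of the action frame-stack shift used in the rollout loops above; JAX's .at[...].set(...) returns a new array, so each result must be re-assigned (in the commented-out block the first shift is not re-assigned and would therefore be a no-op if re-enabled). The helper name is illustrative.

import jax.numpy as jnp

def shift_stacked_actions(stacked_actions: jnp.ndarray, new_action: jnp.ndarray) -> jnp.ndarray:
    # Drop the oldest action and append the newest one; JAX arrays are
    # immutable, so every .at[...].set(...) result must be kept.
    action_dim = new_action.shape[0]
    shifted = stacked_actions.at[:-action_dim].set(stacked_actions[action_dim:])
    return shifted.at[-action_dim:].set(new_action)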

@@ -296,13 +305,80 @@ def plot_error_on_the_trajectory(data):
return fig, axes


def evaluate_runs_for_video(num_data_points: int = 50, control_time_ms: float = 28.2):
if num_data_points == 50:
policy_filenames = [
'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
'=50_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5'\
'/z4rlsbwj/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=50_share_of_x0s='
'0.5_sac_only_from_is=0_use_sim_model=0_0.5/j5ekgar9/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=50_'
'share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5/3c0m3gpi'
'/models/parameters.pkl'

]
elif num_data_points == 800:
policy_filenames = [
'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
'=800_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5' \
'/59ufk8pc/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=800_share_of_x0s='
'0.5_sac_only_from_is=0_use_sim_model=0_0.5/sy4devzz/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=800_share_of_x0s'
'=0.5_sac_only_from_is=0_use_sim_model=0_0.5/vkybm8ml/models/parameters.pkl'
]
elif num_data_points == 2500:
policy_filenames = [
'saved_data/use_sim_prior=0_use_grey_box=0_high_fidelity=0_num_offline_data'
'=2500_share_of_x0s=0.5_sac_only_from_is=0_use_sim_model=0_0.5' \
'/tppyk230/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data=2500_share_of_x0s='
'0.5_sac_only_from_is=0_use_sim_model=0_0.5/xvvy2wde/models/parameters.pkl',

'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=1_num_offline_data=2500_share_of_x0s'
'=0.5_sac_only_from_is=0_use_sim_model=0_0.5/9i1k6d07/models/parameters.pkl'
]
else:
raise AssertionError

dummy_reward_kwargs = {
'ctrl_cost_weight': 0.005,
'margin_factor': 20.0,
'ctrl_diff_weight': 0.0,
'encode_angle': True,
}

for i, param_file in enumerate(policy_filenames):
import cloudpickle
# with open(os.path.join(pre_path, bnn_name), 'rb') as handle:
# bnn_model = cloudpickle.load(handle)

with open(param_file, 'rb') as handle:
policy_params = cloudpickle.load(handle)

run_with_learned_policy(policy_params=policy_params,
bnn_model=None,
project_name='dummy_proj',
group_name='dummy',
run_id=f'dummy{i}',
reward_config=dummy_reward_kwargs,
control_time_ms=control_time_ms)



if __name__ == '__main__':
import pickle

filename_policy = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
'=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/parameters.pkl'
filename_bnn_model = 'saved_data/use_sim_prior=1_use_grey_box=0_high_fidelity=0_num_offline_data' \
'=2500_share_of_x0s=0.5_train_sac_only_from_init_states=0_0.5/tshlnhs0/models/bnn_model.pkl'
'=2500_share_of_x0s=0.5_train_sac_only_from_inyit_states=0_0.5/tshlnhs0/models/bnn_model.pkl'

# with open(filename_policy, 'rb') as handle:
# policy_params = pickle.load(handle)
@@ -317,10 +393,13 @@ def plot_error_on_the_trajectory(data):
# run_id='Butterfly',
# control_time_ms=32,
# )

run_all_hardware_experiments(
project_name_load='OfflineRLRunsWithGreyBox',
project_name_save='OfflineRLRunsWithGreyBox_evaluation',
desired_config={'bandwidth_svgd': 0.2, 'data_from_simulation': 0},
control_time_ms=32,
)
evaluate_runs_for_video(num_data_points=2500, control_time_ms=28.5)
# run_all_hardware_experiments(
# project_name_load='OfflineRLRunsGreyBoxHW2',
# project_name_save='HWEvaluationGb2',
# desired_config={'bandwidth_svgd': 0.2, 'data_from_simulation': 0, 'num_offline_collected_transitions': 5000,
# 'use_sim_model': 0},
# control_time_ms=28.5,
# download_data=True,
# use_grey_box=True,
#)
