Skip to content

Commit

Permalink
changes to data provider for offline rl, hyperparams for offline and online rl
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Jan 4, 2024
1 parent 68e105c commit 65603eb
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
14 changes: 10 additions & 4 deletions experiments/data_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,20 @@ def get_rccar_recorded_data_new(encode_angle: bool = True, skip_first_n_points:

# load and shuffle transitions
transitions = _load_transitions(file_names)
indices = jax.random.permutation(key=jax.random.PRNGKey(9345), x=jnp.arange(0, len(transitions)))
transitions = [transitions[idx] for idx in indices]
# indices = jax.random.permutation(key=jax.random.PRNGKey(9345), x=jnp.arange(0, len(transitions)))
# transitions = [transitions[idx] for idx in indices]

# transform transitions into supervised learning datasets
prep_fn = partial(_rccar_transitions_to_dataset, encode_angles=encode_angle, skip_first_n=skip_first_n_points,
action_delay=action_delay, action_stacking=action_stacking)
x_train, y_train = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[:num_train_traj])))
x_test, y_test = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[num_train_traj:])))
x, y = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions)))
# x_test, y_test = map(lambda x: jnp.concatenate(x, axis=0), zip(*map(prep_fn, transitions[num_train_traj:])))
indices = jnp.arange(start=0, stop=x.shape[0], step=1)
indices = jax.random.shuffle(key=jax.random.PRNGKey(9345), x=indices)
x, y = x[indices], y[indices]
num_test_points = 20_000
x_train, y_train, x_test, y_test = x[:-num_test_points], y[:-num_test_points], \
x[-num_test_points:], y[-num_test_points:]
return x_train, y_train, x_test, y_test


Expand Down
20 changes: 13 additions & 7 deletions experiments/offline_rl_from_recorded_data/launcher.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import exp
from experiments.util import generate_run_commands, generate_base_command, dict_permutations

PROJECT_NAME = 'OfflineRLRunsGreyBoxHWFinal'
PROJECT_NAME = 'OfflineRLRunsGreyHW'

_applicable_configs = {
'horizon_len': [200],
'seed': list(range(5)),
'project_name': [PROJECT_NAME],
'sac_num_env_steps': [1_000_000],
'sac_num_env_steps': [2_000_000],
'num_epochs': [50],
'max_train_steps': [100_000],
'max_train_steps': [40_000],
'min_train_steps': [40_000],
'learnable_likelihood_std': ['yes'],
'include_aleatoric_noise': [1],
Expand All @@ -23,7 +23,7 @@
'eval_on_all_offline_data': [1],
'eval_only_on_init_states': [1],
'share_of_x0s_in_sac_buffer': [0.5],
'bnn_batch_size': [32], # for HW 8 worked the best
'bnn_batch_size': [32],
'likelihood_exponent': [0.5],
'train_sac_only_from_init_states': [0],
'data_from_simulation': [0],
Expand Down Expand Up @@ -87,11 +87,17 @@
# _applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + dict_permutations(
# _applicable_configs_grey_box)

all_flags_combinations = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
sim_flags = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
_applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + \
dict_permutations(_applicable_configs_grey_box_high_fidelity) + \
dict_permutations(_applicable_configs_sim_model_high_fidelity)
dict_permutations(_applicable_configs_grey_box_low_fidelity) + \
dict_permutations(_applicable_configs_sim_model_low_fidelity)

hw_flags = dict_permutations(_applicable_configs_no_sim_prior) + dict_permutations(
_applicable_configs_high_fidelity) + dict_permutations(_applicable_configs_low_fidelity) + \
dict_permutations(_applicable_configs_grey_box_high_fidelity) + \
dict_permutations(_applicable_configs_sim_model_high_fidelity)

all_flags_combinations = sim_flags


def main():
Expand Down
9 changes: 5 additions & 4 deletions experiments/online_rl_hardware/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@

def main(args):
_applicable_configs = {
'prior': ['none_FVSGD', 'none_SVGD', 'high_fidelity', 'low_fidelity'],
'prior': ['none_FVSGD', 'high_fidelity', 'low_fidelity',
'low_fidelity_grey_box'],
'seed': list(range(5)),
'machine': ['local'],
'gpu': [1],
'project_name': ['OnlineRLFewSteps'],
'project_name': ['OnlineRLTestFull'],
'reset_bnn': [1],
'deterministic_policy': [1],
'initial_state_fraction': [0.5],
'bnn_train_steps': [40_000],
'sac_num_env_steps': [250_000],
'num_sac_envs': [64],
'sac_num_env_steps': [500_000],
'num_sac_envs': [128],
'num_env_steps': [100],
'num_f_samples': [512]
}
Expand Down

0 comments on commit 65603eb

Please sign in to comment.