Skip to content

Commit

Permalink
include grey box and sim models for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Nov 30, 2023
1 parent 4fcb13f commit e6321aa
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
4 changes: 4 additions & 0 deletions experiments/online_rl_hardware/online_rl_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
'high_fidelity',
'low_fidelity',
'high_fidelity_no_aditive_GP',
'high_fidelity_grey_box',
'low_fidelity_grey_box',
'high_fidelity_sim',
'low_fidelity_sim'
}


Expand Down
24 changes: 23 additions & 1 deletion experiments/online_rl_hardware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from brax.training.replay_buffers import ReplayBuffer, ReplayBufferState

from sim_transfer.rl.model_based_rl.learned_system import LearnedCarSystem
from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_FSVGD, BNN_SVGD
from sim_transfer.models import BNN_FSVGD_SimPrior, BNN_FSVGD, BNN_SVGD, BNNGreyBox
from sim_transfer.sims.simulators import AdditiveSim, PredictStateChangeWrapper, GaussianProcessSim
from sim_transfer.sims.simulators import RaceCarSim, StackedActionSimWrapper
from sim_transfer.sims.envs import RCCarSimEnv
Expand Down Expand Up @@ -169,6 +169,28 @@ def set_up_bnn_dynamics_model(config: Any, key: jax.random.PRNGKey):
**standard_params,
bandwidth_svgd=1.0,
)
elif config.sim_prior == 'high_fidelity_grey_box' or config.sim_prior == 'low_fidelity_grey_box':
base_bnn = BNN_FSVGD(
**standard_params,
domain=sim.domain,
bandwidth_svgd=config.bandwidth_svgd,
)
bnn = BNNGreyBox(
base_bnn=base_bnn,
sim=sim,
use_base_bnn=True,
)
elif config.sim_prior == 'high_fidelity_sim' or config.sim_prior == 'low_fidelity_sim':
base_bnn = BNN_FSVGD(
**standard_params,
domain=sim.domain,
bandwidth_svgd=config.bandwidth_svgd,
)
bnn = BNNGreyBox(
base_bnn=base_bnn,
sim=sim,
use_base_bnn=False,
)
elif config.sim_prior == 'high_fidelity_no_aditive_GP':
bnn = BNN_FSVGD_SimPrior(
**standard_params,
Expand Down

0 comments on commit e6321aa

Please sign in to comment.