diff --git a/mbpo/optimizers/policy_optimizers/brax_optimizers.py b/mbpo/optimizers/policy_optimizers/brax_optimizers.py index a652259..d0f8d3b 100644 --- a/mbpo/optimizers/policy_optimizers/brax_optimizers.py +++ b/mbpo/optimizers/policy_optimizers/brax_optimizers.py @@ -102,6 +102,7 @@ def train(self, class PPOOptimizer(BraxOptimizer): def __init__(self, true_buffer: UniformSamplingQueue, + system: System | None = None, **ppo_kwargs): super().__init__(agent_class=PPO, system=system, true_buffer=true_buffer, **ppo_kwargs)