Skip to content

Commit

Permalink
minor bug fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
sukhijab committed Nov 10, 2023
1 parent 23e8f62 commit 351569a
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions mbpo/optimizers/policy_optimizers/brax_optimizers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -32,8 +32,8 @@ class BraxOutput(OptimizerTrainingOutPut, Generic[DynamicsParams, RewardParams])
class BraxOptimizer(BaseOptimizer[BraxState, BraxOutput]):
def __init__(self,
agent_class,
system: System,
true_buffer: UniformSamplingQueue,
system: System | None = None,
**agent_kwargs):
super().__init__(system)
self.agent_class = agent_class
Expand All @@ -46,9 +46,9 @@ def __init__(self,
self.set_system(system)

def set_system(self, system: System):
super().set_system(system)
self.key, sys_key, buffer_key = jr.split(self.key, 3)
dummy_true_buffer_state = self.dummy_true_buffer_state(buffer_key)
super().set_system(system)
dummy_env = BraxWrapper(system=self.system,
system_params=self.system.init_params(sys_key),
sample_buffer_state=dummy_true_buffer_state,
Expand Down Expand Up @@ -101,15 +101,14 @@ def train(self,

class PPOOptimizer(BraxOptimizer):
def __init__(self,
system: System,
true_buffer: UniformSamplingQueue,
**ppo_kwargs):
super().__init__(agent_class=PPO, system=system, true_buffer=true_buffer, **ppo_kwargs)


class SACOptimizer(BraxOptimizer):
def __init__(self,
system: System,
true_buffer: UniformSamplingQueue,
system: System | None = None,
**sac_kwargs):
super().__init__(agent_class=SAC, system=system, true_buffer=true_buffer, **sac_kwargs)

0 comments on commit 351569a

Please sign in to comment.