
Add compiled flag to OptimizerWrapper
takuseno committed Nov 3, 2024
1 parent ebd4756 commit bc3d1a1
Showing 21 changed files with 236 additions and 82 deletions.
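The hunks below all apply the same pattern in each algorithm's inner_create_impl: compute compiled = self._config.compile_graph and "cuda" in self._device once, pass it to every optimizer factory's create() call as the new compiled argument, and reuse the same value for the impl's compile_graph. The following is a minimal, self-contained sketch of that pattern. Only the create(named_modules, lr=..., compiled=...) call shape and the CUDA gating are taken from the diff; the OptimizerWrapper and AdamFactory bodies are assumptions for illustration (the optimizer module itself is not part of this excerpt).

import torch
from torch import nn
from torch.optim import Adam

class OptimizerWrapper:
    # Hypothetical wrapper: only the `compiled` argument comes from the diff.
    def __init__(self, optim: torch.optim.Optimizer, compiled: bool):
        self._optim = optim
        # When compilation is requested (at the call sites, only on CUDA
        # devices), wrap the update step once instead of re-checking per step.
        self._step_fn = torch.compile(self._raw_step) if compiled else self._raw_step

    def _raw_step(self) -> None:
        self._optim.step()
        self._optim.zero_grad(set_to_none=True)

    def step(self) -> None:
        self._step_fn()

class AdamFactory:
    # Factory whose create() mirrors the call sites changed in this commit.
    def create(self, named_modules, lr: float, compiled: bool) -> OptimizerWrapper:
        # named_modules() yields (name, module) pairs; taking only each
        # submodule's direct parameters collects every parameter exactly once.
        params = [p for _, m in named_modules for p in m.parameters(recurse=False)]
        return OptimizerWrapper(Adam(params, lr=lr), compiled)

# Call-site pattern from the hunks below: compute the flag once and reuse it
# for every optimizer and for the impl's compile_graph argument.
model = nn.Linear(4, 2)
device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
compile_graph = True  # stands in for self._config.compile_graph
compiled = compile_graph and "cuda" in device
optim = AdamFactory().create(model.named_modules(), lr=3e-4, compiled=compiled)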
12 changes: 9 additions & 3 deletions d3rlpy/algos/qlearning/awac.py
@@ -102,6 +102,8 @@ class AWAC(QLearningAlgoBase[AWACImpl, AWACConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_normal_policy(
observation_shape,
action_size,
@@ -132,10 +134,14 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)

dummy_log_temp = Parameter(torch.zeros(1, 1))
@@ -160,7 +166,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

8 changes: 6 additions & 2 deletions d3rlpy/algos/qlearning/bc.py
@@ -93,7 +93,9 @@ def inner_create_impl(
raise ValueError(f"invalid policy_type: {self._config.policy_type}")

optim = self._config.optim_factory.create(
imitator.named_modules(), lr=self._config.learning_rate
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=False,
)

modules = BCModules(optim=optim, imitator=imitator)
@@ -168,7 +170,9 @@ def inner_create_impl(
)

optim = self._config.optim_factory.create(
imitator.named_modules(), lr=self._config.learning_rate
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=False,
)

modules = DiscreteBCModules(optim=optim, imitator=imitator)
21 changes: 16 additions & 5 deletions d3rlpy/algos/qlearning/bcq.py
@@ -176,6 +176,8 @@ class BCQ(QLearningAlgoBase[BCQImpl, BCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_deterministic_residual_policy(
observation_shape,
action_size,
@@ -230,15 +232,20 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
compiled=compiled,
)

modules = BCQModules(
@@ -266,7 +273,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

@@ -364,6 +371,8 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
@@ -407,7 +416,9 @@ def inner_create_impl(
q_func_params = list(q_funcs.named_modules())
imitator_params = list(imitator.named_modules())
optim = self._config.optim_factory.create(
q_func_params + imitator_params, lr=self._config.learning_rate
q_func_params + imitator_params,
lr=self._config.learning_rate,
compiled=compiled,
)

modules = DiscreteBCQModules(
@@ -427,7 +438,7 @@ def inner_create_impl(
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

21 changes: 16 additions & 5 deletions d3rlpy/algos/qlearning/bear.py
@@ -162,6 +162,8 @@ class BEAR(QLearningAlgoBase[BEARImpl, BEARConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_normal_policy(
observation_shape,
action_size,
@@ -219,21 +221,30 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
compiled=compiled,
)
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=compiled,
)
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.actor_learning_rate
log_alpha.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)

modules = BEARModules(
@@ -268,7 +279,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

19 changes: 14 additions & 5 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -89,6 +89,7 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_normal_policy(
observation_shape,
@@ -129,20 +130,28 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.alpha_learning_rate
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
compiled=compiled,
)
else:
alpha_optim = None
@@ -172,7 +181,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

27 changes: 20 additions & 7 deletions d3rlpy/algos/qlearning/cql.py
@@ -144,6 +144,7 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_normal_policy(
observation_shape,
@@ -184,20 +185,28 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.alpha_learning_rate
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
compiled=compiled,
)
else:
alpha_optim = None
@@ -227,7 +236,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

@@ -303,6 +312,8 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

q_funcs, q_func_forwarder = create_discrete_q_function(
observation_shape,
action_size,
@@ -323,7 +334,9 @@ def inner_create_impl(
)

optim = self._config.optim_factory.create(
q_funcs.named_modules(), lr=self._config.learning_rate
q_funcs.named_modules(),
lr=self._config.learning_rate,
compiled=compiled,
)

modules = DQNModules(
@@ -341,7 +354,7 @@ def inner_create_impl(
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

12 changes: 9 additions & 3 deletions d3rlpy/algos/qlearning/crr.py
@@ -137,6 +137,8 @@ class CRR(QLearningAlgoBase[CRRImpl, CRRConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_normal_policy(
observation_shape,
action_size,
@@ -171,10 +173,14 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)

modules = CRRModules(
@@ -201,7 +207,7 @@ def inner_create_impl(
tau=self._config.tau,
target_update_type=self._config.target_update_type,
target_update_interval=self._config.target_update_interval,
compile_graph=self._config.compile_graph and "cuda" in self._device,
compile_graph=compiled,
device=self._device,
)

12 changes: 9 additions & 3 deletions d3rlpy/algos/qlearning/ddpg.py
@@ -98,6 +98,8 @@ class DDPG(QLearningAlgoBase[DDPGImpl, DDPGConfig]):
def inner_create_impl(
self, observation_shape: Shape, action_size: int
) -> None:
compiled = self._config.compile_graph and "cuda" in self._device

policy = create_deterministic_policy(
observation_shape,
action_size,
@@ -132,10 +134,14 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=compiled,
)

modules = DDPGModules(
@@ -155,7 +161,7 @@ def inner_create_impl(
targ_q_func_forwarder=targ_q_func_forwarder,
gamma=self._config.gamma,
tau=self._config.tau,
compile_graph=self._config.compile_graph,
compile_graph=compiled,
device=self._device,
)

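For context, a hypothetical end-to-end usage of the new flag, assuming compile_graph is a constructor field of the algorithm configs (as the self._config.compile_graph reads above suggest); per the CUDA checks in the hunks, the flag has no effect on CPU devices.

import d3rlpy

# compile_graph only takes effect when the device is CUDA, because each
# inner_create_impl gates it with: compile_graph and "cuda" in device.
cql = d3rlpy.algos.CQLConfig(compile_graph=True).create(device="cuda:0")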
