Support CudaGraph and torch.compile #428

Merged: 15 commits, merged on Nov 3, 2024
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -2,7 +2,7 @@ version: 2
build:
os: ubuntu-22.04
tools:
python: "3.8"
python: "3.10"
sphinx:
builder: html
configuration: docs/conf.py
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ d3rlpy supports Linux, macOS and Windows.

### Dependencies
Installing d3rlpy package will install or upgrade the following packages to satisfy requirements:
- torch>=2.0.0
- torch>=2.5.0
- tqdm>=4.66.3
- gym>=0.26.0
- gymnasium>=1.0.0
5 changes: 5 additions & 0 deletions d3rlpy/__init__.py
@@ -1,3 +1,4 @@
# pylint: disable=protected-access
import random

import gymnasium
@@ -68,6 +69,10 @@
# run healthcheck
run_healthcheck()

if torch.cuda.is_available():
# enable autograd compilation
torch._dynamo.config.compiled_autograd = True
torch.set_float32_matmul_precision("high")

# register Shimmy if available
try:
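The two package-level settings above can be read standalone as below; this is an annotated restatement of the lines added to d3rlpy/__init__.py, and the explanatory comments are editorial rather than part of the diff:

import torch

if torch.cuda.is_available():
    # Let torch.compile trace the backward pass together with the forward
    # pass, so that whole training steps can be compiled.
    torch._dynamo.config.compiled_autograd = True
    # Permit TF32 matrix multiplications on recent NVIDIA GPUs, trading a
    # small amount of float32 precision for throughput.
    torch.set_float32_matmul_precision("high")
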
10 changes: 8 additions & 2 deletions d3rlpy/algos/qlearning/awac.py
@@ -70,6 +70,7 @@ class AWACConfig(LearnableConfig):
n_action_samples (int): Number of sampled actions to calculate
:math:`A^\pi(s_t, a_t)`.
n_critics (int): Number of Q functions for ensemble.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 3e-4
@@ -130,10 +131,14 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)

dummy_log_temp = Parameter(torch.zeros(1, 1))
@@ -158,6 +163,7 @@ def inner_create_impl(
tau=self._config.tau,
lam=self._config.lam,
n_action_samples=self._config.n_action_samples,
compiled=self.compiled,
device=self._device,
)

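A minimal usage sketch for the new flag, assuming compile_graph is exposed on the algorithm configs exactly as documented in the docstring above; the dataset variable and step count are placeholders:

import d3rlpy

# Enable JIT compilation and CUDAGraph for AWAC training.
awac = d3rlpy.algos.AWACConfig(compile_graph=True).create(device="cuda:0")

# `dataset` stands in for any prepared d3rlpy MDPDataset.
awac.fit(dataset, n_steps=100000)
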
12 changes: 10 additions & 2 deletions d3rlpy/algos/qlearning/bc.py
@@ -49,6 +49,7 @@ class BCConfig(LearnableConfig):
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 100
@@ -93,7 +94,9 @@ def inner_create_impl(
raise ValueError(f"invalid policy_type: {self._config.policy_type}")

optim = self._config.optim_factory.create(
imitator.named_modules(), lr=self._config.learning_rate
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=self.compiled,
)

modules = BCModules(optim=optim, imitator=imitator)
@@ -103,6 +106,7 @@
action_size=action_size,
modules=modules,
policy_type=self._config.policy_type,
compiled=self.compiled,
device=self._device,
)

@@ -137,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig):
beta (float): Regularization factor.
observation_scaler (d3rlpy.preprocessing.ObservationScaler):
Observation preprocessor.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

batch_size: int = 100
@@ -168,7 +173,9 @@ def inner_create_impl(
)

optim = self._config.optim_factory.create(
imitator.named_modules(), lr=self._config.learning_rate
imitator.named_modules(),
lr=self._config.learning_rate,
compiled=self.compiled,
)

modules = DiscreteBCModules(optim=optim, imitator=imitator)
@@ -178,6 +185,7 @@
action_size=action_size,
modules=modules,
beta=self._config.beta,
compiled=self.compiled,
device=self._device,
)

17 changes: 14 additions & 3 deletions d3rlpy/algos/qlearning/bcq.py
@@ -137,6 +137,7 @@ class BCQConfig(LearnableConfig):
rl_start_step (int): Steps to start to update policy function and Q
functions. If this is large, RL training would be more stabilized.
beta (float): KL regularization term for Conditional VAE.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-3
@@ -228,15 +229,20 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
compiled=self.compiled,
)

modules = BCQModules(
@@ -264,6 +270,7 @@ def inner_create_impl(
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
rl_start_step=self._config.rl_start_step,
compiled=self.compiled,
device=self._device,
)

@@ -331,6 +338,7 @@ class DiscreteBCQConfig(LearnableConfig):
target_update_interval (int): Interval to update the target network.
share_encoder (bool): Flag to share encoder between Q-function and
imitation models.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
@@ -402,7 +410,9 @@ def inner_create_impl(
q_func_params = list(q_funcs.named_modules())
imitator_params = list(imitator.named_modules())
optim = self._config.optim_factory.create(
q_func_params + imitator_params, lr=self._config.learning_rate
q_func_params + imitator_params,
lr=self._config.learning_rate,
compiled=self.compiled,
)

modules = DiscreteBCQModules(
@@ -422,6 +432,7 @@
gamma=self._config.gamma,
action_flexibility=self._config.action_flexibility,
beta=self._config.beta,
compiled=self.compiled,
device=self._device,
)

19 changes: 15 additions & 4 deletions d3rlpy/algos/qlearning/bear.py
@@ -114,6 +114,7 @@ class BEARConfig(LearnableConfig):
policy training.
warmup_steps (int): Number of steps to warmup the policy
function.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
@@ -217,21 +218,30 @@ def inner_create_impl(
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)
vae_optim = self._config.imitator_optim_factory.create(
list(vae_encoder.named_modules())
+ list(vae_decoder.named_modules()),
lr=self._config.imitator_learning_rate,
compiled=self.compiled,
)
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=self.compiled,
)
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.actor_learning_rate
log_alpha.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)

modules = BEARModules(
@@ -266,6 +276,7 @@ def inner_create_impl(
mmd_sigma=self._config.mmd_sigma,
vae_kl_weight=self._config.vae_kl_weight,
warmup_steps=self._config.warmup_steps,
compiled=self.compiled,
device=self._device,
)

19 changes: 14 additions & 5 deletions d3rlpy/algos/qlearning/cal_ql.py
@@ -69,6 +69,7 @@ class CalQLConfig(CQLConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

def create(
@@ -88,7 +89,6 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."

policy = create_normal_policy(
observation_shape,
action_size,
@@ -128,20 +128,28 @@
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=self.compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.alpha_learning_rate
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
compiled=self.compiled,
)
else:
alpha_optim = None
@@ -171,6 +179,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compiled=self.compiled,
device=self._device,
)

25 changes: 19 additions & 6 deletions d3rlpy/algos/qlearning/cql.py
@@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig):
:math:`\log{\sum_a \exp{Q(s, a)}}`.
soft_q_backup (bool): Flag to use SAC-style backup.
max_q_backup (bool): Flag to sample max Q-values for target.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

actor_learning_rate: float = 1e-4
@@ -142,7 +143,6 @@ def inner_create_impl(
assert not (
self._config.soft_q_backup and self._config.max_q_backup
), "soft_q_backup and max_q_backup are mutually exclusive."

policy = create_normal_policy(
observation_shape,
action_size,
@@ -182,20 +182,28 @@
)

actor_optim = self._config.actor_optim_factory.create(
policy.named_modules(), lr=self._config.actor_learning_rate
policy.named_modules(),
lr=self._config.actor_learning_rate,
compiled=self.compiled,
)
critic_optim = self._config.critic_optim_factory.create(
q_funcs.named_modules(), lr=self._config.critic_learning_rate
q_funcs.named_modules(),
lr=self._config.critic_learning_rate,
compiled=self.compiled,
)
if self._config.temp_learning_rate > 0:
temp_optim = self._config.temp_optim_factory.create(
log_temp.named_modules(), lr=self._config.temp_learning_rate
log_temp.named_modules(),
lr=self._config.temp_learning_rate,
compiled=self.compiled,
)
else:
temp_optim = None
if self._config.alpha_learning_rate > 0:
alpha_optim = self._config.alpha_optim_factory.create(
log_alpha.named_modules(), lr=self._config.alpha_learning_rate
log_alpha.named_modules(),
lr=self._config.alpha_learning_rate,
compiled=self.compiled,
)
else:
alpha_optim = None
@@ -225,6 +233,7 @@ def inner_create_impl(
n_action_samples=self._config.n_action_samples,
soft_q_backup=self._config.soft_q_backup,
max_q_backup=self._config.max_q_backup,
compiled=self.compiled,
device=self._device,
)

@@ -272,6 +281,7 @@ class DiscreteCQLConfig(LearnableConfig):
target_update_interval (int): Interval to synchronize the target
network.
alpha (float): :math:`\alpha` value above.
compile_graph (bool): Flag to enable JIT compilation and CUDAGraph.
"""

learning_rate: float = 6.25e-5
@@ -318,7 +328,9 @@ def inner_create_impl(
)

optim = self._config.optim_factory.create(
q_funcs.named_modules(), lr=self._config.learning_rate
q_funcs.named_modules(),
lr=self._config.learning_rate,
compiled=self.compiled,
)

modules = DQNModules(
@@ -336,6 +348,7 @@
target_update_interval=self._config.target_update_interval,
gamma=self._config.gamma,
alpha=self._config.alpha,
compiled=self.compiled,
device=self._device,
)

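Every optimizer factory above now receives a compiled= argument. As a generic illustration only, not d3rlpy's actual implementation, one way a wrapper can honor such a flag is to compile the parameter update with torch.compile in "reduce-overhead" mode, which captures CUDA graphs to cut per-step launch overhead:

import torch

def make_update_step(optimizer: torch.optim.Optimizer, compiled: bool):
    """Return a callable that applies gradients and then clears them."""

    def step() -> None:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    if compiled:
        # "reduce-overhead" enables CUDA graph capture in torch.compile.
        return torch.compile(step, mode="reduce-overhead")
    return step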