From f6de60283c8f6620916746d2605aee245fdbcee7 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 2 Nov 2024 11:03:15 +0900 Subject: [PATCH 01/15] Add CudaGraphWrapper --- d3rlpy/__init__.py | 3 + d3rlpy/algos/qlearning/awac.py | 3 + d3rlpy/algos/qlearning/bcq.py | 6 ++ d3rlpy/algos/qlearning/bear.py | 3 + d3rlpy/algos/qlearning/cal_ql.py | 2 + d3rlpy/algos/qlearning/cql.py | 6 ++ d3rlpy/algos/qlearning/ddpg.py | 3 + d3rlpy/algos/qlearning/dqn.py | 6 ++ d3rlpy/algos/qlearning/iql.py | 3 + d3rlpy/algos/qlearning/plas.py | 5 + d3rlpy/algos/qlearning/rebrac.py | 3 + d3rlpy/algos/qlearning/sac.py | 3 + d3rlpy/algos/qlearning/td3.py | 3 + d3rlpy/algos/qlearning/td3_plus_bc.py | 3 + d3rlpy/algos/qlearning/torch/awac_impl.py | 4 +- d3rlpy/algos/qlearning/torch/bcq_impl.py | 62 ++++++++---- d3rlpy/algos/qlearning/torch/bear_impl.py | 60 ++++++++---- d3rlpy/algos/qlearning/torch/cql_impl.py | 46 ++++----- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 68 ++++++++----- d3rlpy/algos/qlearning/torch/dqn_impl.py | 31 ++++-- d3rlpy/algos/qlearning/torch/iql_impl.py | 6 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 36 ++++--- d3rlpy/algos/qlearning/torch/rebrac_impl.py | 4 +- d3rlpy/algos/qlearning/torch/sac_impl.py | 12 +-- d3rlpy/algos/qlearning/torch/td3_impl.py | 7 +- .../algos/qlearning/torch/td3_plus_bc_impl.py | 4 +- .../algos/transformer/decision_transformer.py | 13 +-- .../torch/decision_transformer_impl.py | 63 +++++++++--- d3rlpy/models/torch/policies.py | 7 ++ d3rlpy/ope/torch/fqe_impl.py | 2 +- d3rlpy/optimizers/optimizers.py | 8 +- d3rlpy/torch_utility.py | 95 +++++++++++++++++++ d3rlpy/types.py | 11 ++- reproductions/finetuning/cal_ql_finetune.py | 1 + reproductions/offline/cql.py | 12 ++- reproductions/offline/decision_transformer.py | 1 + tests/test_torch_utility.py | 2 +- 37 files changed, 455 insertions(+), 152 deletions(-) diff --git a/d3rlpy/__init__.py b/d3rlpy/__init__.py index c560a39f..db57d410 100644 --- a/d3rlpy/__init__.py +++ b/d3rlpy/__init__.py @@ -68,6 +68,9 @@ def seed(n: int) -> None: # run healthcheck run_healthcheck() +# enable autograd compilation +torch._dynamo.config.compiled_autograd = True +torch.set_float32_matmul_precision("high") # register Shimmy if available try: diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 86987333..a0a9a72a 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -70,6 +70,7 @@ class AWACConfig(LearnableConfig): n_action_samples (int): Number of sampled actions to calculate :math:`A^\pi(s_t, a_t)`. n_critics (int): Number of Q functions for ensemble. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -85,6 +86,7 @@ class AWACConfig(LearnableConfig): lam: float = 1.0 n_action_samples: int = 1 n_critics: int = 2 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -158,6 +160,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 4bec28d7..a9aed836 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -137,6 +137,7 @@ class BCQConfig(LearnableConfig): rl_start_step (int): Steps to start to update policy function and Q functions. If this is large, RL training would be more stabilized. 
beta (float): KL reguralization term for Conditional VAE. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-3 @@ -159,6 +160,7 @@ class BCQConfig(LearnableConfig): action_flexibility: float = 0.05 rl_start_step: int = 0 beta: float = 0.5 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -264,6 +266,7 @@ def inner_create_impl( action_flexibility=self._config.action_flexibility, beta=self._config.beta, rl_start_step=self._config.rl_start_step, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) @@ -331,6 +334,7 @@ class DiscreteBCQConfig(LearnableConfig): target_update_interval (int): Interval to update the target network. share_encoder (bool): Flag to share encoder between Q-function and imitation models. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ learning_rate: float = 6.25e-5 @@ -344,6 +348,7 @@ class DiscreteBCQConfig(LearnableConfig): beta: float = 0.5 target_update_interval: int = 8000 share_encoder: bool = True + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -422,6 +427,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 70d83c5c..796cd58f 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -114,6 +114,7 @@ class BEARConfig(LearnableConfig): policy training. warmup_steps (int): Number of steps to warmup the policy function. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-4 @@ -145,6 +146,7 @@ class BEARConfig(LearnableConfig): mmd_sigma: float = 20.0 vae_kl_weight: float = 0.5 warmup_steps: int = 40000 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -266,6 +268,7 @@ def inner_create_impl( mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, warmup_steps=self._config.warmup_steps, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index 8cc551f2..7b7e9e26 100644 --- a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -69,6 +69,7 @@ class CalQLConfig(CQLConfig): :math:`\log{\sum_a \exp{Q(s, a)}}`. soft_q_backup (bool): Flag to use SAC-style backup. max_q_backup (bool): Flag to sample max Q-values for target. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ def create( @@ -171,6 +172,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, + compile=self._config.compile, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index f36cb501..f12036ce 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -100,6 +100,7 @@ class CQLConfig(LearnableConfig): :math:`\log{\sum_a \exp{Q(s, a)}}`. soft_q_backup (bool): Flag to use SAC-style backup. max_q_backup (bool): Flag to sample max Q-values for target. + compile (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 1e-4 @@ -124,6 +125,7 @@ class CQLConfig(LearnableConfig): n_action_samples: int = 10 soft_q_backup: bool = False max_q_backup: bool = False + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -225,6 +227,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) @@ -272,6 +275,7 @@ class DiscreteCQLConfig(LearnableConfig): target_update_interval (int): Interval to synchronize the target network. alpha (float): math:`\alpha` value above. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ learning_rate: float = 6.25e-5 @@ -283,6 +287,7 @@ class DiscreteCQLConfig(LearnableConfig): n_critics: int = 1 target_update_interval: int = 8000 alpha: float = 1.0 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -336,6 +341,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 0cff98bb..bd47ef5a 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -69,6 +69,7 @@ class DDPGConfig(LearnableConfig): gamma (float): Discount factor. tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 256 @@ -81,6 +82,7 @@ class DDPGConfig(LearnableConfig): q_func_factory: QFunctionFactory = make_q_func_field() tau: float = 0.005 n_critics: int = 1 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -153,6 +155,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, + compile=self._config.compile, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 0a485ce9..8216ea69 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -44,6 +44,7 @@ class DQNConfig(LearnableConfig): gamma (float): Discount factor. n_critics (int): Number of Q functions for ensemble. target_update_interval (int): Interval to update the target network. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 32 @@ -54,6 +55,7 @@ class DQNConfig(LearnableConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -106,6 +108,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, + compile=self._config.compile, device=self._device, ) @@ -151,6 +154,7 @@ class DoubleDQNConfig(DQNConfig): n_critics (int): Number of Q functions. target_update_interval (int): Interval to synchronize the target network. + compile (bool): Flag to enable JIT compilation and CUDAGraph. 
""" batch_size: int = 32 @@ -161,6 +165,7 @@ class DoubleDQNConfig(DQNConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -213,6 +218,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, + compile=self._config.compile, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 25ec78b9..daf8a278 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -80,6 +80,7 @@ class IQLConfig(LearnableConfig): weight_temp (float): Inverse temperature value represented as :math:`\beta`. max_weight (float): Maximum advantage weight value to clip. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -96,6 +97,7 @@ class IQLConfig(LearnableConfig): expectile: float = 0.7 weight_temp: float = 3.0 max_weight: float = 100.0 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -175,6 +177,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 8e68b32e..acd9b307 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -77,6 +77,7 @@ class PLASConfig(LearnableConfig): lam (float): Weight factor for critic ensemble. warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-4 @@ -96,6 +97,7 @@ class PLASConfig(LearnableConfig): lam: float = 0.75 warmup_steps: int = 500000 beta: float = 0.5 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -197,6 +199,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) @@ -247,6 +250,7 @@ class PLASWithPerturbationConfig(PLASConfig): action_flexibility (float): Output scale of perturbation layer. warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ action_flexibility: float = 0.05 @@ -373,6 +377,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index 4aabc09f..95371a25 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -71,6 +71,7 @@ class ReBRACConfig(LearnableConfig): critic_beta (float): :math:`\beta_2` value. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. + compile (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 1e-3 @@ -89,6 +90,7 @@ class ReBRACConfig(LearnableConfig): actor_beta: float = 0.001 critic_beta: float = 0.01 update_actor_interval: int = 2 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -166,6 +168,7 @@ def inner_create_impl( actor_beta=self._config.actor_beta, critic_beta=self._config.critic_beta, update_actor_interval=self._config.update_actor_interval, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 5c93c958..e9dc8666 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -94,6 +94,7 @@ class SACConfig(LearnableConfig): tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. initial_temperature (float): Initial temperature value. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -110,6 +111,7 @@ class SACConfig(LearnableConfig): tau: float = 0.005 n_critics: int = 2 initial_temperature: float = 1.0 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -188,6 +190,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 7b54fe8b..9c1642a8 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -74,6 +74,7 @@ class TD3Config(LearnableConfig): target_smoothing_clip (float): Clipping range for target noise. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -90,6 +91,7 @@ class TD3Config(LearnableConfig): target_smoothing_sigma: float = 0.2 target_smoothing_clip: float = 0.5 update_actor_interval: int = 2 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -165,6 +167,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, + compile=self._config.compile, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index a792ae0d..a71604f8 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -65,6 +65,7 @@ class TD3PlusBCConfig(LearnableConfig): alpha (float): :math:`\alpha` value. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. + compile (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 3e-4 @@ -82,6 +83,7 @@ class TD3PlusBCConfig(LearnableConfig): target_smoothing_clip: float = 0.5 alpha: float = 2.5 update_actor_interval: int = 2 + compile: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -158,6 +160,7 @@ def inner_create_impl( target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index b489126c..6a296bbe 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -33,6 +33,7 @@ def __init__( tau: float, lam: float, n_action_samples: int, + compile: bool, device: str, ): super().__init__( @@ -43,13 +44,14 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._lam = lam self._n_action_samples = n_action_samples def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> SACActorLoss: # compute log probability dist = build_gaussian_distribution(action) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index c3e7b49f..0cc60053 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -19,6 +19,7 @@ ) from ....optimizers import OptimizerWrapper from ....torch_utility import ( + CudaGraphWrapper, TorchMiniBatch, expand_and_repeat_recursively, flatten_left_recursively, @@ -69,6 +70,7 @@ def __init__( action_flexibility: float, beta: float, rl_start_step: int, + compile: bool, device: str, ): super().__init__( @@ -79,6 +81,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._lam = lam @@ -86,18 +89,41 @@ def __init__( self._action_flexibility = action_flexibility self._beta = beta self._rl_start_step = rl_start_step + self._compute_imitator_grad = ( + CudaGraphWrapper(self.compute_imitator_grad) + if compile + else self.compute_imitator_grad + ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: value = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" ) return DDPGBaseActorLoss(-value[0].mean()) - def update_imitator( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: + # forward policy + batch_size = get_batch_size(batch.observations) + latent = torch.randn( + batch_size, 2 * self._action_size, device=self._device + ) + clipped_latent = latent.clamp(-0.5, 0.5) + sampled_action = self._modules.vae_decoder( + x=batch.observations, + latent=clipped_latent, + ) + action = self._modules.policy(batch.observations, sampled_action) + + self._modules.actor_optim.zero_grad() + loss = self.compute_actor_loss(batch, action) + loss.actor_loss.backward() + return loss + + def compute_imitator_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -107,8 +133,12 @@ def update_imitator( beta=self._beta, ) loss.backward() - 
self._modules.vae_optim.step(grad_step) - return {"vae_loss": float(loss.cpu().detach().numpy())} + return {"loss": loss} + + def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_imitator_grad(batch) + self._modules.vae_optim.step() + return {"vae_loss": float(loss["loss"].cpu().detach().numpy())} def _repeat_observation(self, x: TorchObservation) -> TorchObservation: # (batch_size, *obs_shape) -> (batch_size, n, *obs_shape) @@ -186,25 +216,13 @@ def inner_update( ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_imitator(batch, grad_step)) + metrics.update(self.update_imitator(batch)) if grad_step < self._rl_start_step: return metrics - # forward policy - batch_size = get_batch_size(batch.observations) - latent = torch.randn( - batch_size, 2 * self._action_size, device=self._device - ) - clipped_latent = latent.clamp(-0.5, 0.5) - sampled_action = self._modules.vae_decoder( - x=batch.observations, - latent=clipped_latent, - ) - action = self._modules.policy(batch.observations, sampled_action) - # update models - metrics.update(self.update_critic(batch, grad_step)) - metrics.update(self.update_actor(batch, action, grad_step)) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) self.update_critic_target() self.update_actor_target() return metrics @@ -237,6 +255,7 @@ def __init__( gamma: float, action_flexibility: float, beta: float, + compile: bool, device: str, ): super().__init__( @@ -247,6 +266,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, + compile=compile, device=device, ) self._action_flexibility = action_flexibility diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 44d8ed2b..8de78ad4 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -17,6 +17,7 @@ ) from ....optimizers import OptimizerWrapper from ....torch_utility import ( + CudaGraphWrapper, TorchMiniBatch, expand_and_repeat_recursively, flatten_left_recursively, @@ -87,6 +88,7 @@ def __init__( mmd_sigma: float, vae_kl_weight: float, warmup_steps: int, + compile: bool, device: str, ): super().__init__( @@ -97,6 +99,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._alpha_threshold = alpha_threshold @@ -108,14 +111,24 @@ def __init__( self._mmd_sigma = mmd_sigma self._vae_kl_weight = vae_kl_weight self._warmup_steps = warmup_steps + self._compute_warmup_actor_grad = ( + CudaGraphWrapper(self.compute_warmup_actor_grad) + if compile + else self.compute_warmup_actor_grad + ) + self._compute_imitator_grad = ( + CudaGraphWrapper(self.compute_imitator_grad) + if compile + else self.compute_imitator_grad + ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> BEARActorLoss: - loss = super().compute_actor_loss(batch, action, grad_step) + loss = super().compute_actor_loss(batch, action) mmd_loss = self._compute_mmd_loss(batch.observations) if self._modules.alpha_optim: - self.update_alpha(mmd_loss, grad_step) + self.update_alpha(mmd_loss) return BEARActorLoss( actor_loss=loss.actor_loss + mmd_loss, temp_loss=loss.temp_loss, @@ -124,28 +137,36 @@ def compute_actor_loss( alpha=get_parameter(self._modules.log_alpha).exp()[0][0], ) - def warmup_actor( - self, batch: TorchMiniBatch, grad_step: int - ) 
-> Dict[str, float]: + def compute_warmup_actor_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.actor_optim.zero_grad() loss = self._compute_mmd_loss(batch.observations) loss.backward() - self._modules.actor_optim.step(grad_step) - return {"actor_loss": float(loss.cpu().detach().numpy())} + return {"loss": loss} + + def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_warmup_actor_grad(batch) + self._modules.actor_optim.step() + return {"actor_loss": float(loss["loss"].cpu().detach().numpy())} def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor: mmd = self._compute_mmd(obs_t) alpha = get_parameter(self._modules.log_alpha).exp() return (alpha * (mmd - self._alpha_threshold)).mean() - def update_imitator( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_imitator_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = self.compute_imitator_loss(batch) loss.backward() - self._modules.vae_optim.step(grad_step) - return {"imitator_loss": float(loss.cpu().detach().numpy())} + return {"loss": loss} + + def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_imitator_grad(batch) + self._modules.vae_optim.step() + return {"imitator_loss": float(loss["loss"].cpu().detach().numpy())} def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: return compute_vae_error( @@ -156,12 +177,12 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: beta=self._vae_kl_weight, ) - def update_alpha(self, mmd_loss: torch.Tensor, grad_step: int) -> None: + def update_alpha(self, mmd_loss: torch.Tensor) -> None: assert self._modules.alpha_optim self._modules.alpha_optim.zero_grad() loss = -mmd_loss loss.backward(retain_graph=True) - self._modules.alpha_optim.step(grad_step) + self._modules.alpha_optim.step() # clip for stability get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0) @@ -278,13 +299,12 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_imitator(batch, grad_step)) - metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_imitator(batch)) + metrics.update(self.update_critic(batch)) if grad_step < self._warmup_steps: - actor_loss = self.warmup_actor(batch, grad_step) + actor_loss = self.warmup_actor(batch) else: - action = self._modules.policy(batch.observations) - actor_loss = self.update_actor(batch, action, grad_step) + actor_loss = self.update_actor(batch) metrics.update(actor_loss) self.update_critic_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 02d3daf2..f07f75b3 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -60,6 +60,7 @@ def __init__( n_action_samples: int, soft_q_backup: bool, max_q_backup: bool, + compile: bool, device: str, ): super().__init__( @@ -70,6 +71,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._alpha_threshold = alpha_threshold @@ -79,32 +81,38 @@ def __init__( self._max_q_backup = max_q_backup def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int + self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> CQLCriticLoss: - loss = super().compute_critic_loss(batch, q_tpn, 
grad_step) + loss = super().compute_critic_loss(batch, q_tpn) conservative_loss = self._compute_conservative_loss( obs_t=batch.observations, act_t=batch.actions, obs_tp1=batch.next_observations, returns_to_go=batch.returns_to_go, ) + if self._modules.alpha_optim: - self.update_alpha(conservative_loss, grad_step) + self.update_alpha(conservative_loss.detach()) + + # clip for stability + log_alpha = get_parameter(self._modules.log_alpha) + clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0] + scaled_conservative_loss = clipped_alpha * conservative_loss + return CQLCriticLoss( - critic_loss=loss.critic_loss + conservative_loss.sum(), - conservative_loss=conservative_loss.sum(), - alpha=get_parameter(self._modules.log_alpha).exp()[0][0], + critic_loss=loss.critic_loss + scaled_conservative_loss.sum(), + conservative_loss=scaled_conservative_loss.sum(), + alpha=clipped_alpha, ) - def update_alpha( - self, conservative_loss: torch.Tensor, grad_step: int - ) -> None: + def update_alpha(self, conservative_loss: torch.Tensor) -> None: assert self._modules.alpha_optim self._modules.alpha_optim.zero_grad() - # the original implementation does scale the loss value - loss = -conservative_loss.mean() - loss.backward(retain_graph=True) - self._modules.alpha_optim.step(grad_step) + log_alpha = get_parameter(self._modules.log_alpha) + clipped_alpha = log_alpha.exp().clamp(0, 1e6) + loss = -(clipped_alpha * conservative_loss).mean() + loss.backward() + self._modules.alpha_optim.step() def _compute_policy_is_values( self, @@ -188,15 +196,7 @@ def _compute_conservative_loss( loss = (logsumexp - data_values).mean(dim=[1, 2]) - # clip for stability - log_alpha = get_parameter(self._modules.log_alpha) - clipped_alpha = log_alpha.exp().clamp(0, 1e6)[0][0] - - return ( - clipped_alpha - * self._conservative_weight - * (loss - self._alpha_threshold) - ) + return self._conservative_weight * (loss - self._alpha_threshold) def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: if self._soft_q_backup: @@ -247,6 +247,7 @@ def __init__( target_update_interval: int, gamma: float, alpha: float, + compile: bool, device: str, ): super().__init__( @@ -257,6 +258,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, + compile=compile, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 937b6284..750a1c11 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -1,20 +1,25 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Dict +from typing import Callable, Dict import torch from torch import nn from torch.optim import Optimizer -from d3rlpy.optimizers.optimizers import OptimizerWrapper - from ....dataclass_utils import asdict_as_float from ....models.torch import ( ActionOutput, ContinuousEnsembleQFunctionForwarder, Policy, ) -from ....torch_utility import Modules, TorchMiniBatch, hard_sync, soft_sync +from ....optimizers.optimizers import OptimizerWrapper +from ....torch_utility import ( + CudaGraphWrapper, + Modules, + TorchMiniBatch, + hard_sync, + soft_sync, +) from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase from .utility import ContinuousQFunctionMixin @@ -52,6 +57,8 @@ class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): _modules: DDPGBaseModules + _compute_crtic_grad: Callable[[TorchMiniBatch], 
DDPGBaseCriticLoss] + _compute_actor_grad: Callable[[TorchMiniBatch], DDPGBaseActorLoss] _gamma: float _tau: float _q_func_forwarder: ContinuousEnsembleQFunctionForwarder @@ -66,6 +73,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, + compile: bool, device: str, ): super().__init__( @@ -78,20 +86,32 @@ def __init__( self._tau = tau self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder + self._compute_critic_grad = ( + CudaGraphWrapper(self.compute_critic_grad) + if compile + else self.compute_critic_grad + ) + self._compute_actor_grad = ( + CudaGraphWrapper(self.compute_actor_grad) + if compile + else self.compute_actor_grad + ) hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) - def update_critic( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_critic_grad(self, batch: TorchMiniBatch) -> DDPGBaseCriticLoss: self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) - loss = self.compute_critic_loss(batch, q_tpn, grad_step) + loss = self.compute_critic_loss(batch, q_tpn) loss.critic_loss.backward() - self._modules.critic_optim.step(grad_step) + return loss + + def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_critic_grad(batch) + self._modules.critic_optim.step() return asdict_as_float(loss) def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int + self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> DDPGBaseCriticLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, @@ -103,30 +123,32 @@ def compute_critic_loss( ) return DDPGBaseCriticLoss(loss) - def update_actor( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int - ) -> Dict[str, float]: - # Q function should be inference mode for stability - self._modules.q_funcs.eval() + def compute_actor_grad(self, batch: TorchMiniBatch) -> DDPGBaseActorLoss: + action = self._modules.policy(batch.observations) self._modules.actor_optim.zero_grad() - loss = self.compute_actor_loss(batch, action, grad_step) + loss = self.compute_actor_loss(batch, action) loss.actor_loss.backward() - self._modules.actor_optim.step(grad_step) + return loss + + def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + # Q function should be inference mode for stability + self._modules.q_funcs.eval() + loss = self._compute_actor_grad(batch) + self._modules.actor_optim.step() return asdict_as_float(loss) def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch, grad_step)) - metrics.update(self.update_actor(batch, action, grad_step)) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) self.update_critic_target() return metrics @abstractmethod def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: pass @@ -178,6 +200,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, + compile: bool, device: str, ): super().__init__( @@ -188,12 +211,13 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) hard_sync(self._modules.targ_policy, self._modules.policy) def compute_actor_loss( - self, batch: 
TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 68831bed..b0087806 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict +from typing import Callable, Dict import torch from torch import nn @@ -8,7 +8,12 @@ from ....dataclass_utils import asdict_as_float from ....models.torch import DiscreteEnsembleQFunctionForwarder from ....optimizers.optimizers import OptimizerWrapper -from ....torch_utility import Modules, TorchMiniBatch, hard_sync +from ....torch_utility import ( + CudaGraphWrapper, + Modules, + TorchMiniBatch, + hard_sync, +) from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase from .utility import DiscreteQFunctionMixin @@ -30,6 +35,7 @@ class DQNLoss: class DQNImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _modules: DQNModules + _compute_grad: Callable[[TorchMiniBatch], DQNLoss] _gamma: float _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder @@ -44,6 +50,7 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, target_update_interval: int, gamma: float, + compile: bool, device: str, ): super().__init__( @@ -56,23 +63,27 @@ def __init__( self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder self._target_update_interval = target_update_interval + self._compute_grad = ( + CudaGraphWrapper(self.compute_grad) # type: ignore + if compile + else self.compute_grad + ) hard_sync(modules.targ_q_funcs, modules.q_funcs) - def inner_update( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss: self._modules.optim.zero_grad() - q_tpn = self.compute_target(batch) - loss = self.compute_loss(batch, q_tpn) - loss.loss.backward() - self._modules.optim.step(grad_step) + return loss + def inner_update( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: + loss = self._compute_grad(batch) + self._modules.optim.step() if grad_step % self._target_update_interval == 0: self.update_target() - return asdict_as_float(loss) def compute_loss( diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index fa1bb445..6d0c0030 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -51,6 +51,7 @@ def __init__( expectile: float, weight_temp: float, max_weight: float, + compile: bool, device: str, ): super().__init__( @@ -61,6 +62,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._expectile = expectile @@ -68,7 +70,7 @@ def __init__( self._max_weight = max_weight def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int + self, batch: TorchMiniBatch, q_tpn: torch.Tensor ) -> IQLCriticLoss: q_loss = self._q_func_forwarder.compute_error( observations=batch.observations, @@ -90,7 +92,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: return self._modules.value_func(batch.next_observations) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, 
grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: # compute log probability dist = build_gaussian_distribution(action) diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index ccd42c09..970f095f 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -13,7 +13,7 @@ compute_vae_error, ) from ....optimizers import OptimizerWrapper -from ....torch_utility import TorchMiniBatch, soft_sync +from ....torch_utility import CudaGraphWrapper, TorchMiniBatch, soft_sync from ....types import Shape, TorchObservation from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules @@ -52,6 +52,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, + compile: bool, device: str, ): super().__init__( @@ -62,15 +63,21 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._lam = lam self._beta = beta self._warmup_steps = warmup_steps + self._compute_imitator_grad = ( + CudaGraphWrapper(self.compute_imitator_grad) + if compile + else self.compute_imitator_grad + ) - def update_imitator( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_imitator_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -80,11 +87,15 @@ def update_imitator( beta=self._beta, ) loss.backward() - self._modules.vae_optim.step(grad_step) - return {"vae_loss": float(loss.cpu().detach().numpy())} + return {"loss": loss} + + def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_imitator_grad(batch) + self._modules.vae_optim.step() + return {"vae_loss": float(loss["loss"].cpu().detach().numpy())} def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: latent_actions = 2.0 * action.squashed_mu actions = self._modules.vae_decoder(batch.observations, latent_actions) @@ -125,11 +136,10 @@ def inner_update( metrics = {} if grad_step < self._warmup_steps: - metrics.update(self.update_imitator(batch, grad_step)) + metrics.update(self.update_imitator(batch)) else: - action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch, grad_step)) - metrics.update(self.update_actor(batch, action, grad_step)) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) self.update_actor_target() self.update_critic_target() @@ -157,6 +167,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, + compile: bool, device: str, ): super().__init__( @@ -170,11 +181,12 @@ def __init__( lam=lam, beta=beta, warmup_steps=warmup_steps, + compile=compile, device=device, ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: latent_actions = 2.0 * action.squashed_mu actions = self._modules.vae_decoder(batch.observations, latent_actions) diff --git a/d3rlpy/algos/qlearning/torch/rebrac_impl.py b/d3rlpy/algos/qlearning/torch/rebrac_impl.py index b3480508..b416c441 100644 --- a/d3rlpy/algos/qlearning/torch/rebrac_impl.py +++ b/d3rlpy/algos/qlearning/torch/rebrac_impl.py @@ -29,6 +29,7 @@ def __init__( actor_beta: float, critic_beta: float, update_actor_interval: int, + compile: 
bool, device: str, ): super().__init__( @@ -42,13 +43,14 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, + compile=compile, device=device, ) self._actor_beta = actor_beta self._critic_beta = critic_beta def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> TD3PlusBCActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 431644e3..c1726d72 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -59,6 +59,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, + compile: bool, device: str, ): super().__init__( @@ -69,17 +70,18 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> SACActorLoss: dist = build_squashed_gaussian_distribution(action) sampled_action, log_prob = dist.sample_with_log_prob() if self._modules.temp_optim: - temp_loss = self.update_temp(log_prob, grad_step) + temp_loss = self.update_temp(log_prob) else: temp_loss = torch.tensor( 0.0, dtype=torch.float32, device=sampled_action.device @@ -95,16 +97,14 @@ def compute_actor_loss( temp=get_parameter(self._modules.log_temp).exp()[0][0], ) - def update_temp( - self, log_prob: torch.Tensor, grad_step: int - ) -> torch.Tensor: + def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: assert self._modules.temp_optim self._modules.temp_optim.zero_grad() with torch.no_grad(): targ_temp = log_prob - self._action_size loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() - self._modules.temp_optim.step(grad_step) + self._modules.temp_optim.step() return loss def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index c739f026..ef034fe5 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -27,6 +27,7 @@ def __init__( target_smoothing_sigma: float, target_smoothing_clip: float, update_actor_interval: int, + compile: bool, device: str, ): super().__init__( @@ -37,6 +38,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile=compile, device=device, ) self._target_smoothing_sigma = target_smoothing_sigma @@ -65,12 +67,11 @@ def inner_update( ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_critic(batch)) # delayed policy update if grad_step % self._update_actor_interval == 0: - action = self._modules.policy(batch.observations) - metrics.update(self.update_actor(batch, action, grad_step)) + metrics.update(self.update_actor(batch)) self.update_critic_target() self.update_actor_target() diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index c6103e7b..7163e9fb 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -33,6 +33,7 @@ def __init__( target_smoothing_clip: float, alpha: 
float, update_actor_interval: int, + compile: bool, device: str, ): super().__init__( @@ -46,12 +47,13 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, + compile=compile, device=device, ) self._alpha = alpha def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> TD3PlusBCActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 030cadb3..55313dfd 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -59,7 +59,7 @@ class DecisionTransformerConfig(TransformerConfig): activation_type (str): Type of activation function. position_encoding_type (d3rlpy.PositionEncodingType): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). - compile (bool): (experimental) Flag to enable JIT compilation. + compile (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 64 @@ -111,10 +111,6 @@ def inner_create_impl( transformer.named_modules(), lr=self._config.learning_rate ) - # JIT compile - if self._config.compile: - transformer = torch.compile(transformer, fullgraph=True) - modules = DecisionTransformerModules( transformer=transformer, optim=optim, @@ -125,6 +121,7 @@ def inner_create_impl( action_size=action_size, modules=modules, device=self._device, + compile=self._config.compile and "cuda" in self._device, ) def get_action_type(self) -> ActionSpace: @@ -166,7 +163,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_tokens (int): Number of tokens to warmup learning rate scheduler. final_tokens (int): Final number of tokens for learning rate scheduler. - compile (bool): (experimental) Flag to enable JIT compilation. + compile (bool): Flag to enable JIT compilation and CUDAGraph. 
""" batch_size: int = 128 @@ -223,9 +220,6 @@ def inner_create_impl( optim = self._config.optim_factory.create( transformer.named_modules(), lr=self._config.learning_rate ) - # JIT compile - if self._config.compile: - transformer = torch.compile(transformer, fullgraph=True) modules = DiscreteDecisionTransformerModules( transformer=transformer, @@ -239,6 +233,7 @@ def inner_create_impl( warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, + compile=self._config.compile and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 33624a02..0d6809da 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -1,6 +1,6 @@ import dataclasses import math -from typing import Dict +from typing import Callable, Dict import torch import torch.nn.functional as F @@ -10,7 +10,11 @@ DiscreteDecisionTransformer, ) from ....optimizers import OptimizerWrapper -from ....torch_utility import Modules, TorchTrajectoryMiniBatch +from ....torch_utility import ( + CudaGraphWrapper, + Modules, + TorchTrajectoryMiniBatch, +) from ....types import Shape from ..base import TransformerAlgoImplBase from ..inputs import TorchTransformerInput @@ -31,6 +35,22 @@ class DecisionTransformerModules(Modules): class DecisionTransformerImpl(TransformerAlgoImplBase): _modules: DecisionTransformerModules + _compute_grad: Callable[[TorchTrajectoryMiniBatch], Dict[str, torch.Tensor]] + + def __init__( + self, + observation_shape: Shape, + action_size: int, + modules: Modules, + compile: bool, + device: str, + ): + super().__init__(observation_shape, action_size, modules, device) + self._compute_grad = ( + CudaGraphWrapper(self.compute_grad) # type: ignore + if compile + else self.compute_grad + ) def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) @@ -40,14 +60,20 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) -> (A,) return action[0][-1] - def inner_update( - self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_grad( + self, batch: TorchTrajectoryMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.optim.zero_grad() loss = self.compute_loss(batch) loss.backward() - self._modules.optim.step(grad_step) - return {"loss": float(loss.cpu().detach().numpy())} + return {"loss": loss} + + def inner_update( + self, batch: TorchTrajectoryMiniBatch, grad_step: int + ) -> Dict[str, float]: + loss = self._compute_grad(batch) + self._modules.optim.step() + return {"loss": float(loss["loss"].cpu().detach().numpy())} def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor: action = self._modules.transformer( @@ -73,6 +99,7 @@ class DiscreteDecisionTransformerImpl(TransformerAlgoImplBase): _final_tokens: int _initial_learning_rate: float _tokens: int + _compute_grad: Callable[[TorchTrajectoryMiniBatch], Dict[str, torch.Tensor]] def __init__( self, @@ -82,6 +109,7 @@ def __init__( warmup_tokens: int, final_tokens: int, initial_learning_rate: float, + compile: bool, device: str, ): super().__init__( @@ -93,6 +121,11 @@ def __init__( self._warmup_tokens = warmup_tokens self._final_tokens = final_tokens self._initial_learning_rate = initial_learning_rate + self._compute_grad = ( + CudaGraphWrapper(self.compute_grad) # type: ignore + if 
compile + else self.compute_grad + ) # TODO: Include stateful information in checkpoint. self._tokens = 0 @@ -104,13 +137,19 @@ def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) -> (A,) return logits[0][-1] - def inner_update( - self, batch: TorchTrajectoryMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_grad( + self, batch: TorchTrajectoryMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.optim.zero_grad() loss = self.compute_loss(batch) loss.backward() - self._modules.optim.step(grad_step) + return {"loss": loss} + + def inner_update( + self, batch: TorchTrajectoryMiniBatch, grad_step: int + ) -> Dict[str, float]: + loss = self._compute_grad(batch) + self._modules.optim.step() # schedule learning rate self._tokens += int(batch.masks.sum().cpu().detach().numpy()) @@ -128,7 +167,7 @@ def inner_update( param_group["lr"] = new_learning_rate return { - "loss": float(loss.cpu().detach().numpy()), + "loss": float(loss["loss"].cpu().detach().numpy()), "learning_rate": new_learning_rate, } diff --git a/d3rlpy/models/torch/policies.py b/d3rlpy/models/torch/policies.py index 3142e7e3..a0815f56 100644 --- a/d3rlpy/models/torch/policies.py +++ b/d3rlpy/models/torch/policies.py @@ -26,6 +26,13 @@ class ActionOutput(NamedTuple): squashed_mu: torch.Tensor logstd: Optional[torch.Tensor] + def copy_(self, src: "ActionOutput") -> None: + self.mu.copy_(src.mu) + self.squashed_mu.copy_(src.squashed_mu) + if self.logstd: + assert src.logstd is not None + self.logstd.copy_(src.logstd) + def build_gaussian_distribution(action: ActionOutput) -> GaussianDistribution: assert action.logstd is not None diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index 9626d693..a1549573 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -111,7 +111,7 @@ def inner_update( self._modules.optim.zero_grad() loss.backward() - self._modules.optim.step(grad_step) + self._modules.optim.step() if grad_step % self._target_update_interval == 0: self.update_target() diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py index 82a9bbf6..1719421c 100644 --- a/d3rlpy/optimizers/optimizers.py +++ b/d3rlpy/optimizers/optimizers.py @@ -62,10 +62,10 @@ def __init__( self._clip_grad_norm = clip_grad_norm self._lr_scheduler = lr_scheduler - def zero_grad(self) -> None: - self._optim.zero_grad() + def zero_grad(self, set_to_none: bool = False) -> None: + self._optim.zero_grad(set_to_none=set_to_none) - def step(self, grad_step: int) -> None: + def step(self) -> None: """Updates parameters. 
Args: @@ -252,6 +252,8 @@ def create_optimizer( eps=self.eps, weight_decay=self.weight_decay, amsgrad=self.amsgrad, + capturable=False, + differentiable=False, ) @staticmethod diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index b16bf26e..07357a7b 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -4,6 +4,7 @@ Any, BinaryIO, Dict, + Generic, Iterator, Optional, Sequence, @@ -17,8 +18,10 @@ import torch import torch.nn.functional as F from torch import nn +from torch.cuda import CUDAGraph from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer +from typing_extensions import Protocol, Self from .dataclass_utils import asdict_without_copy from .dataset import TrajectoryMiniBatch, TransitionMiniBatch @@ -51,6 +54,7 @@ "eval_api", "train_api", "View", + "CudaGraphWrapper", ] @@ -128,6 +132,20 @@ def convert_to_numpy_recursively( raise ValueError(f"invalid array type: {type(array)}") +_T = TypeVar("_T", bound=Union[torch.Tensor, Sequence[torch.Tensor]]) + + +def copy_recursively(src: _T, dst: _T) -> None: + if isinstance(src, torch.Tensor) and isinstance(dst, torch.Tensor): + dst.copy_(src) + elif isinstance(src, (list, tuple)) and isinstance(dst, (list, tuple)): + [d.copy_(s) for s, d in zip(src, dst)] + else: + raise ValueError( + f"invalid inpu types: src={type(src)}, dst={type(dst)}" + ) + + def get_device(x: Union[torch.Tensor, Sequence[torch.Tensor]]) -> str: if isinstance(x, torch.Tensor): return str(x.device) @@ -256,6 +274,17 @@ def from_batch( numpy_batch=batch, ) + def copy_(self, src: Self) -> None: + assert self.device == src.device, "incompatible device" + copy_recursively(src.observations, self.observations) + self.actions.copy_(src.actions) + self.rewards.copy_(src.rewards) + copy_recursively(src.next_observations, self.next_observations) + self.next_actions.copy_(src.next_actions) + self.returns_to_go.copy_(src.returns_to_go) + self.terminals.copy_(src.terminals) + self.intervals.copy_(src.intervals) + @dataclasses.dataclass(frozen=True) class TorchTrajectoryMiniBatch: @@ -309,6 +338,16 @@ def from_batch( numpy_batch=batch, ) + def copy_(self, src: Self) -> None: + assert self.device == src.device, "incompatible device" + copy_recursively(src.observations, self.observations) + self.actions.copy_(src.actions) + self.rewards.copy_(src.rewards) + self.returns_to_go.copy_(src.returns_to_go) + self.terminals.copy_(src.terminals) + self.timesteps.copy_(src.timesteps) + self.masks.copy_(src.masks) + _TModule = TypeVar("_TModule", bound=nn.Module) @@ -458,3 +497,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert x.shape[-1] % 2 == 0 a, b = x.chunk(2, dim=-1) return a * F.gelu(b) + + +BatchT = TypeVar( + "BatchT", + bound=Union[TorchMiniBatch, TorchTrajectoryMiniBatch], + contravariant=True, +) +RetT = TypeVar("RetT", covariant=True) + + +class CudaGraphFunc(Generic[BatchT, RetT], Protocol): + def __call__(self, batch: BatchT) -> RetT: ... 
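+
+# Usage sketch (hedged; it mirrors how the DQN/DDPG impl classes in this same
+# patch wire the wrapper up): wrap the function that zeroes gradients, computes
+# the loss and calls backward(), then call the wrapper in place of that
+# function. After the warmup iterations the captured CUDA graph is replayed and
+# new batches are copied into the captured static inputs via
+# TorchMiniBatch.copy_(). Optimizer steps stay outside the captured graph.
+#
+#     self._compute_grad = (
+#         CudaGraphWrapper(self.compute_grad)
+#         if compile
+#         else self.compute_grad
+#     )
+#     ...
+#     loss = self._compute_grad(batch)  # replays the CUDA graph after warmup
+#     self._modules.optim.step()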
+ + +class CudaGraphWrapper(Generic[BatchT, RetT]): + _func: CudaGraphFunc[BatchT, RetT] + _input: TorchTrajectoryMiniBatch + _graph: Optional[CUDAGraph] + _inpt: Optional[BatchT] + _out: Optional[RetT] + + def __init__( + self, + func: CudaGraphFunc[BatchT, RetT], + warmup_steps: int = 3, + compile: bool = True, + ): + self._func = torch.compile(func) if compile else func + self._step = 0 + self._graph = None + self._inpt = None + self._out = None + self._warmup_steps = warmup_steps + self._warmup_stream = torch.cuda.Stream() + + def __call__(self, batch: BatchT) -> RetT: + if self._step < self._warmup_steps: # warmup + self._warmup_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._warmup_stream): + out = self._func(batch) + torch.cuda.current_stream().wait_stream(self._warmup_stream) + if self._step == self._warmup_steps - 1: # build graph + self._graph = torch.cuda.CUDAGraph() + self._inpt = batch + with torch.cuda.graph(self._graph): + self._out = self._func(self._inpt) + if self._step >= self._warmup_steps: # reuse cuda graph + assert self._inpt + assert self._out + assert self._graph + with torch.no_grad(): + self._inpt.copy_(batch) # type: ignore + self._graph.replay() + out = self._out + self._step += 1 + return out diff --git a/d3rlpy/types.py b/d3rlpy/types.py index 2d532c34..e1e50342 100644 --- a/d3rlpy/types.py +++ b/d3rlpy/types.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence, Union +from typing import Any, Sequence, Type, TypeVar, Union import gym import gymnasium @@ -20,6 +20,7 @@ "TorchObservation", "GymEnv", "OptimizerWrapperProto", + "assert_cast", ] @@ -42,3 +43,11 @@ class OptimizerWrapperProto(Protocol): @property def optim(self) -> Optimizer: raise NotImplementedError + + +T = TypeVar("T") + + +def assert_cast(obj_type: Type[T], obj: Any) -> T: + assert isinstance(obj, obj_type) + return obj diff --git a/reproductions/finetuning/cal_ql_finetune.py b/reproductions/finetuning/cal_ql_finetune.py index 35a492d2..d937c600 100644 --- a/reproductions/finetuning/cal_ql_finetune.py +++ b/reproductions/finetuning/cal_ql_finetune.py @@ -51,6 +51,7 @@ def main() -> None: alpha_threshold=0.8, reward_scaler=reward_scaler, max_q_backup=True, + compile=True, ).create(device=args.gpu) # pretraining diff --git a/reproductions/offline/cql.py b/reproductions/offline/cql.py index 734e5e11..29562508 100644 --- a/reproductions/offline/cql.py +++ b/reproductions/offline/cql.py @@ -1,4 +1,5 @@ import argparse +import math import d3rlpy @@ -16,10 +17,10 @@ def main() -> None: d3rlpy.seed(args.seed) d3rlpy.envs.seed_env(env, args.seed) - encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256]) + encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256]) if "medium-v0" in args.dataset: - conservative_weight = 10.0 + conservative_weight = 5.0 else: conservative_weight = 5.0 @@ -27,12 +28,15 @@ def main() -> None: actor_learning_rate=1e-4, critic_learning_rate=3e-4, temp_learning_rate=1e-4, + alpha_learning_rate=3e-4, + initial_alpha=math.e, actor_encoder_factory=encoder, critic_encoder_factory=encoder, batch_size=256, n_action_samples=10, - alpha_learning_rate=0.0, - conservative_weight=conservative_weight, + alpha_threshold=10, + conservative_weight=5.0, + compile=True, ).create(device=args.gpu) cql.fit( diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py index f3849f86..4cfb4a97 100644 --- a/reproductions/offline/decision_transformer.py +++ 
b/reproductions/offline/decision_transformer.py @@ -46,6 +46,7 @@ def main() -> None: num_heads=1, num_layers=3, max_timestep=1000, + compile=False, ).create(device=args.gpu) dt.fit( diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index fa739784..9f96484e 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -145,7 +145,7 @@ def test_reset_optimizer_states() -> None: # instantiate optimizer state y = impl.fc1(torch.rand(100)).sum() y.backward() - impl.optim.step(0) + impl.optim.step() # check if state is not empty state = copy.deepcopy(impl.optim.optim.state) From 52097073643bdbc4f8f7525a5edc682b3932bcf8 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 2 Nov 2024 14:56:06 +0900 Subject: [PATCH 02/15] Fix lint errors --- d3rlpy/__init__.py | 8 +- d3rlpy/algos/qlearning/awac.py | 6 +- d3rlpy/algos/qlearning/bcq.py | 12 +-- d3rlpy/algos/qlearning/bear.py | 6 +- d3rlpy/algos/qlearning/cal_ql.py | 4 +- d3rlpy/algos/qlearning/cql.py | 12 +-- d3rlpy/algos/qlearning/crr.py | 3 + d3rlpy/algos/qlearning/ddpg.py | 6 +- d3rlpy/algos/qlearning/dqn.py | 12 +-- d3rlpy/algos/qlearning/iql.py | 6 +- d3rlpy/algos/qlearning/nfq.py | 3 + d3rlpy/algos/qlearning/plas.py | 10 +- d3rlpy/algos/qlearning/rebrac.py | 6 +- d3rlpy/algos/qlearning/sac.py | 9 +- d3rlpy/algos/qlearning/td3.py | 6 +- d3rlpy/algos/qlearning/td3_plus_bc.py | 6 +- d3rlpy/algos/qlearning/torch/awac_impl.py | 4 +- d3rlpy/algos/qlearning/torch/bc_impl.py | 8 +- d3rlpy/algos/qlearning/torch/bcq_impl.py | 15 +-- d3rlpy/algos/qlearning/torch/bear_impl.py | 18 ++-- d3rlpy/algos/qlearning/torch/cql_impl.py | 8 +- d3rlpy/algos/qlearning/torch/crr_impl.py | 9 +- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 16 +-- d3rlpy/algos/qlearning/torch/dqn_impl.py | 4 +- d3rlpy/algos/qlearning/torch/iql_impl.py | 4 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 15 +-- d3rlpy/algos/qlearning/torch/rebrac_impl.py | 4 +- d3rlpy/algos/qlearning/torch/sac_impl.py | 99 +++++++++++-------- d3rlpy/algos/qlearning/torch/td3_impl.py | 4 +- .../algos/qlearning/torch/td3_plus_bc_impl.py | 4 +- .../algos/transformer/decision_transformer.py | 14 ++- .../torch/decision_transformer_impl.py | 8 +- d3rlpy/optimizers/optimizers.py | 2 - d3rlpy/torch_utility.py | 33 ++++--- reproductions/finetuning/cal_ql_finetune.py | 1 - reproductions/offline/cql.py | 5 +- reproductions/offline/decision_transformer.py | 1 - 37 files changed, 209 insertions(+), 182 deletions(-) diff --git a/d3rlpy/__init__.py b/d3rlpy/__init__.py index db57d410..6a73196a 100644 --- a/d3rlpy/__init__.py +++ b/d3rlpy/__init__.py @@ -1,3 +1,4 @@ +# pylint: disable=protected-access import random import gymnasium @@ -68,9 +69,10 @@ def seed(n: int) -> None: # run healthcheck run_healthcheck() -# enable autograd compilation -torch._dynamo.config.compiled_autograd = True -torch.set_float32_matmul_precision("high") +if torch.cuda.is_available(): + # enable autograd compilation + torch._dynamo.config.compiled_autograd = True + torch.set_float32_matmul_precision("high") # register Shimmy if available try: diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index a0a9a72a..03e93705 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -70,7 +70,7 @@ class AWACConfig(LearnableConfig): n_action_samples (int): Number of sampled actions to calculate :math:`A^\pi(s_t, a_t)`. n_critics (int): Number of Q functions for ensemble. - compile (bool): Flag to enable JIT compilation and CUDAGraph. 
+ compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -86,7 +86,7 @@ class AWACConfig(LearnableConfig): lam: float = 1.0 n_action_samples: int = 1 n_critics: int = 2 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -160,7 +160,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index a9aed836..810501fd 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -137,7 +137,7 @@ class BCQConfig(LearnableConfig): rl_start_step (int): Steps to start to update policy function and Q functions. If this is large, RL training would be more stabilized. beta (float): KL reguralization term for Conditional VAE. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-3 @@ -160,7 +160,7 @@ class BCQConfig(LearnableConfig): action_flexibility: float = 0.05 rl_start_step: int = 0 beta: float = 0.5 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -266,7 +266,7 @@ def inner_create_impl( action_flexibility=self._config.action_flexibility, beta=self._config.beta, rl_start_step=self._config.rl_start_step, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) @@ -334,7 +334,7 @@ class DiscreteBCQConfig(LearnableConfig): target_update_interval (int): Interval to update the target network. share_encoder (bool): Flag to share encoder between Q-function and imitation models. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ learning_rate: float = 6.25e-5 @@ -348,7 +348,7 @@ class DiscreteBCQConfig(LearnableConfig): beta: float = 0.5 target_update_interval: int = 8000 share_encoder: bool = True - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -427,7 +427,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 796cd58f..034200b7 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -114,7 +114,7 @@ class BEARConfig(LearnableConfig): policy training. warmup_steps (int): Number of steps to warmup the policy function. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 1e-4 @@ -146,7 +146,7 @@ class BEARConfig(LearnableConfig): mmd_sigma: float = 20.0 vae_kl_weight: float = 0.5 warmup_steps: int = 40000 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -268,7 +268,7 @@ def inner_create_impl( mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, warmup_steps=self._config.warmup_steps, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index 7b7e9e26..a6fd2a40 100644 --- a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -69,7 +69,7 @@ class CalQLConfig(CQLConfig): :math:`\log{\sum_a \exp{Q(s, a)}}`. soft_q_backup (bool): Flag to use SAC-style backup. max_q_backup (bool): Flag to sample max Q-values for target. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ def create( @@ -172,7 +172,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile=self._config.compile, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index f12036ce..526f9c94 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -100,7 +100,7 @@ class CQLConfig(LearnableConfig): :math:`\log{\sum_a \exp{Q(s, a)}}`. soft_q_backup (bool): Flag to use SAC-style backup. max_q_backup (bool): Flag to sample max Q-values for target. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-4 @@ -125,7 +125,7 @@ class CQLConfig(LearnableConfig): n_action_samples: int = 10 soft_q_backup: bool = False max_q_backup: bool = False - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -227,7 +227,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) @@ -275,7 +275,7 @@ class DiscreteCQLConfig(LearnableConfig): target_update_interval (int): Interval to synchronize the target network. alpha (float): math:`\alpha` value above. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" learning_rate: float = 6.25e-5 @@ -287,7 +287,7 @@ class DiscreteCQLConfig(LearnableConfig): n_critics: int = 1 target_update_interval: int = 8000 alpha: float = 1.0 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -341,7 +341,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 6fa9828f..46b3b981 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -99,6 +99,7 @@ class CRRConfig(LearnableConfig): ``soft`` target update. update_actor_interval (int): Interval to update policy function used with ``hard`` target update. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -120,6 +121,7 @@ class CRRConfig(LearnableConfig): tau: float = 5e-3 target_update_interval: int = 100 update_actor_interval: int = 1 + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -199,6 +201,7 @@ def inner_create_impl( tau=self._config.tau, target_update_type=self._config.target_update_type, target_update_interval=self._config.target_update_interval, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index bd47ef5a..3144d160 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -69,7 +69,7 @@ class DDPGConfig(LearnableConfig): gamma (float): Discount factor. tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 256 @@ -82,7 +82,7 @@ class DDPGConfig(LearnableConfig): q_func_factory: QFunctionFactory = make_q_func_field() tau: float = 0.005 n_critics: int = 1 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -155,7 +155,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile=self._config.compile, + compile_graph=self._config.compile_graph, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 8216ea69..a4a7e8f7 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -44,7 +44,7 @@ class DQNConfig(LearnableConfig): gamma (float): Discount factor. n_critics (int): Number of Q functions for ensemble. target_update_interval (int): Interval to update the target network. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" batch_size: int = 32 @@ -55,7 +55,7 @@ class DQNConfig(LearnableConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -108,7 +108,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, - compile=self._config.compile, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) @@ -154,7 +154,7 @@ class DoubleDQNConfig(DQNConfig): n_critics (int): Number of Q functions. target_update_interval (int): Interval to synchronize the target network. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 32 @@ -165,7 +165,7 @@ class DoubleDQNConfig(DQNConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -218,7 +218,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile=self._config.compile, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index daf8a278..a95b79f2 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -80,7 +80,7 @@ class IQLConfig(LearnableConfig): weight_temp (float): Inverse temperature value represented as :math:`\beta`. max_weight (float): Maximum advantage weight value to clip. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -97,7 +97,7 @@ class IQLConfig(LearnableConfig): expectile: float = 0.7 weight_temp: float = 3.0 max_weight: float = 100.0 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -177,7 +177,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 7280dace..9855fa1c 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -47,6 +47,7 @@ class NFQConfig(LearnableConfig): batch_size (int): Mini-batch size. gamma (float): Discount factor. n_critics (int): Number of Q functions for ensemble. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" learning_rate: float = 6.25e-5 @@ -56,6 +57,7 @@ class NFQConfig(LearnableConfig): batch_size: int = 32 gamma: float = 0.99 n_critics: int = 1 + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -108,6 +110,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=1, gamma=self._config.gamma, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index acd9b307..e5abc75a 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -77,7 +77,7 @@ class PLASConfig(LearnableConfig): lam (float): Weight factor for critic ensemble. warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 1e-4 @@ -97,7 +97,7 @@ class PLASConfig(LearnableConfig): lam: float = 0.75 warmup_steps: int = 500000 beta: float = 0.5 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -199,7 +199,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) @@ -250,7 +250,7 @@ class PLASWithPerturbationConfig(PLASConfig): action_flexibility (float): Output scale of perturbation layer. warmup_steps (int): Number of steps to warmup the VAE. beta (float): KL reguralization term for Conditional VAE. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ action_flexibility: float = 0.05 @@ -377,7 +377,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index 95371a25..cc237910 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -71,7 +71,7 @@ class ReBRACConfig(LearnableConfig): critic_beta (float): :math:`\beta_2` value. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 1e-3 @@ -90,7 +90,7 @@ class ReBRACConfig(LearnableConfig): actor_beta: float = 0.001 critic_beta: float = 0.01 update_actor_interval: int = 2 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -168,7 +168,7 @@ def inner_create_impl( actor_beta=self._config.actor_beta, critic_beta=self._config.critic_beta, update_actor_interval=self._config.update_actor_interval, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index e9dc8666..71d8b8e9 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -94,7 +94,7 @@ class SACConfig(LearnableConfig): tau (float): Target network synchronization coefficiency. n_critics (int): Number of Q functions for ensemble. initial_temperature (float): Initial temperature value. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -111,7 +111,7 @@ class SACConfig(LearnableConfig): tau: float = 0.005 n_critics: int = 2 initial_temperature: float = 1.0 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -190,7 +190,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) @@ -252,6 +252,7 @@ class DiscreteSACConfig(LearnableConfig): gamma (float): Discount factor. n_critics (int): Number of Q functions for ensemble. initial_temperature (float): Initial temperature value. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -268,6 +269,7 @@ class DiscreteSACConfig(LearnableConfig): n_critics: int = 2 initial_temperature: float = 1.0 target_update_interval: int = 8000 + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -350,6 +352,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 9c1642a8..ddad6eb0 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -74,7 +74,7 @@ class TD3Config(LearnableConfig): target_smoothing_clip (float): Clipping range for target noise. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" actor_learning_rate: float = 3e-4 @@ -91,7 +91,7 @@ class TD3Config(LearnableConfig): target_smoothing_sigma: float = 0.2 target_smoothing_clip: float = 0.5 update_actor_interval: int = 2 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -167,7 +167,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, - compile=self._config.compile, + compile_graph=self._config.compile_graph, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index a71604f8..bdbda4b6 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -65,7 +65,7 @@ class TD3PlusBCConfig(LearnableConfig): alpha (float): :math:`\alpha` value. update_actor_interval (int): Interval to update policy function described as `delayed policy update` in the paper. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ actor_learning_rate: float = 3e-4 @@ -83,7 +83,7 @@ class TD3PlusBCConfig(LearnableConfig): target_smoothing_clip: float = 0.5 alpha: float = 2.5 update_actor_interval: int = 2 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -160,7 +160,7 @@ def inner_create_impl( target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index 6a296bbe..aad6b7b2 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -33,7 +33,7 @@ def __init__( tau: float, lam: float, n_action_samples: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -44,7 +44,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._lam = lam diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 14ac9f6e..8680ef5b 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -47,15 +47,13 @@ def __init__( device=device, ) - def update_imitator( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: self._modules.optim.zero_grad() loss = self.compute_loss(batch.observations, batch.actions) loss.loss.backward() - self._modules.optim.step(grad_step) + self._modules.optim.step() return asdict_as_float(loss) @@ -76,7 +74,7 @@ def inner_predict_value( def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: - return self.update_imitator(batch, grad_step) + return self.update_imitator(batch) @dataclasses.dataclass(frozen=True) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index 0cc60053..b3033508 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -1,6 +1,6 @@ import dataclasses import 
math -from typing import Dict, cast +from typing import Callable, Dict, cast import torch import torch.nn.functional as F @@ -50,6 +50,7 @@ class BCQModules(DDPGBaseModules): class BCQImpl(DDPGBaseImpl): _modules: BCQModules + _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] _lam: float _n_action_samples: int _action_flexibility: float @@ -70,7 +71,7 @@ def __init__( action_flexibility: float, beta: float, rl_start_step: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -81,7 +82,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._lam = lam @@ -90,8 +91,8 @@ def __init__( self._beta = beta self._rl_start_step = rl_start_step self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) - if compile + CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + if compile_graph else self.compute_imitator_grad ) @@ -255,7 +256,7 @@ def __init__( gamma: float, action_flexibility: float, beta: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -266,7 +267,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, - compile=compile, + compile_graph=compile_graph, device=device, ) self._action_flexibility = action_flexibility diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 8de78ad4..3358d487 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict, Optional +from typing import Callable, Dict, Optional import torch @@ -60,6 +60,10 @@ class BEARActorLoss(SACActorLoss): class BEARImpl(SACImpl): _modules: BEARModules + _compute_warmup_actor_grad: Callable[ + [TorchMiniBatch], Dict[str, torch.Tensor] + ] + _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] _alpha_threshold: float _lam: float _n_action_samples: int @@ -88,7 +92,7 @@ def __init__( mmd_sigma: float, vae_kl_weight: float, warmup_steps: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -99,7 +103,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._alpha_threshold = alpha_threshold @@ -112,13 +116,13 @@ def __init__( self._vae_kl_weight = vae_kl_weight self._warmup_steps = warmup_steps self._compute_warmup_actor_grad = ( - CudaGraphWrapper(self.compute_warmup_actor_grad) - if compile + CudaGraphWrapper(self.compute_warmup_actor_grad) # type: ignore + if compile_graph else self.compute_warmup_actor_grad ) self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) - if compile + CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + if compile_graph else self.compute_imitator_grad ) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index f07f75b3..e2c6753a 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -60,7 +60,7 @@ def __init__( n_action_samples: int, soft_q_backup: bool, max_q_backup: bool, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -71,7 +71,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, 
device=device, ) self._alpha_threshold = alpha_threshold @@ -247,7 +247,7 @@ def __init__( target_update_interval: int, gamma: float, alpha: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -258,7 +258,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, - compile=compile, + compile_graph=compile_graph, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index b726c9ff..2e2ccd1b 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -55,6 +55,7 @@ def __init__( tau: float, target_update_type: str, target_update_interval: int, + compile_graph: bool, device: str, ): super().__init__( @@ -65,6 +66,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, + compile_graph=compile_graph, device=device, ) self._beta = beta @@ -76,7 +78,7 @@ def __init__( self._target_update_interval = target_update_interval def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int + self, batch: TorchMiniBatch, action: ActionOutput ) -> DDPGBaseActorLoss: # compute log probability dist = build_gaussian_distribution(action) @@ -186,9 +188,8 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch, grad_step)) - metrics.update(self.update_actor(batch, action, grad_step)) + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) if self._target_update_type == "hard": if grad_step % self._target_update_interval == 0: diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 750a1c11..3d88dcfd 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -57,7 +57,7 @@ class DDPGBaseImpl( ContinuousQFunctionMixin, QLearningAlgoImplBase, metaclass=ABCMeta ): _modules: DDPGBaseModules - _compute_crtic_grad: Callable[[TorchMiniBatch], DDPGBaseCriticLoss] + _compute_critic_grad: Callable[[TorchMiniBatch], DDPGBaseCriticLoss] _compute_actor_grad: Callable[[TorchMiniBatch], DDPGBaseActorLoss] _gamma: float _tau: float @@ -73,7 +73,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -87,13 +87,13 @@ def __init__( self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder self._compute_critic_grad = ( - CudaGraphWrapper(self.compute_critic_grad) - if compile + CudaGraphWrapper(self.compute_critic_grad) # type: ignore + if compile_graph else self.compute_critic_grad ) self._compute_actor_grad = ( - CudaGraphWrapper(self.compute_actor_grad) - if compile + CudaGraphWrapper(self.compute_actor_grad) # type: ignore + if compile_graph else self.compute_actor_grad ) hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) @@ -200,7 +200,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -211,7 +211,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) hard_sync(self._modules.targ_policy, 
self._modules.policy) diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index b0087806..0ecfaee9 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -50,7 +50,7 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, target_update_interval: int, gamma: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -65,7 +65,7 @@ def __init__( self._target_update_interval = target_update_interval self._compute_grad = ( CudaGraphWrapper(self.compute_grad) # type: ignore - if compile + if compile_graph else self.compute_grad ) hard_sync(modules.targ_q_funcs, modules.q_funcs) diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 6d0c0030..76c55086 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -51,7 +51,7 @@ def __init__( expectile: float, weight_temp: float, max_weight: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -62,7 +62,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._expectile = expectile diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 970f095f..71620cdf 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Dict +from typing import Callable, Dict import torch @@ -36,6 +36,7 @@ class PLASModules(DDPGBaseModules): class PLASImpl(DDPGBaseImpl): _modules: PLASModules + _compute_imitator_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] _lam: float _beta: float _warmup_steps: int @@ -52,7 +53,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -63,15 +64,15 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._lam = lam self._beta = beta self._warmup_steps = warmup_steps self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) - if compile + CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + if compile_graph else self.compute_imitator_grad ) @@ -167,7 +168,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -181,7 +182,7 @@ def __init__( lam=lam, beta=beta, warmup_steps=warmup_steps, - compile=compile, + compile_graph=compile_graph, device=device, ) diff --git a/d3rlpy/algos/qlearning/torch/rebrac_impl.py b/d3rlpy/algos/qlearning/torch/rebrac_impl.py index b416c441..ba4af2be 100644 --- a/d3rlpy/algos/qlearning/torch/rebrac_impl.py +++ b/d3rlpy/algos/qlearning/torch/rebrac_impl.py @@ -29,7 +29,7 @@ def __init__( actor_beta: float, critic_beta: float, update_actor_interval: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -43,7 +43,7 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, - compile=compile, + compile_graph=compile_graph, device=device, ) self._actor_beta = actor_beta diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index c1726d72..57ab0b39 
100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -1,10 +1,11 @@ import dataclasses import math -from typing import Dict, Optional +from typing import Callable, Dict, Optional import torch import torch.nn.functional as F from torch import nn +from torch.distributions import Categorical from torch.optim import Optimizer from ....models.torch import ( @@ -19,7 +20,12 @@ get_parameter, ) from ....optimizers import OptimizerWrapper -from ....torch_utility import Modules, TorchMiniBatch, hard_sync +from ....torch_utility import ( + CudaGraphWrapper, + Modules, + TorchMiniBatch, + hard_sync, +) from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase from .ddpg_impl import DDPGBaseActorLoss, DDPGBaseImpl, DDPGBaseModules @@ -59,7 +65,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -70,7 +76,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) @@ -142,6 +148,8 @@ class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): _q_func_forwarder: DiscreteEnsembleQFunctionForwarder _targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder _target_update_interval: int + _compute_critic_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] + _compute_actor_grad: Callable[[TorchMiniBatch], Dict[str, torch.Tensor]] def __init__( self, @@ -152,6 +160,7 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, target_update_interval: int, gamma: float, + compile_graph: bool, device: str, ): super().__init__( @@ -164,20 +173,31 @@ def __init__( self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder self._target_update_interval = target_update_interval + self._compute_critic_grad = ( + CudaGraphWrapper(self.compute_critic_grad) # type: ignore + if compile_graph + else self.compute_critic_grad + ) + self._compute_actor_grad = ( + CudaGraphWrapper(self.compute_actor_grad) # type: ignore + if compile_graph + else self.compute_actor_grad + ) hard_sync(modules.targ_q_funcs, modules.q_funcs) - def update_critic( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + def compute_critic_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.critic_optim.zero_grad() - q_tpn = self.compute_target(batch) loss = self.compute_critic_loss(batch, q_tpn) - loss.backward() - self._modules.critic_optim.step(grad_step) + return {"loss": loss} - return {"critic_loss": float(loss.cpu().detach().numpy())} + def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + loss = self._compute_critic_grad(batch) + self._modules.critic_optim.step() + return {"critic_loss": float(loss["loss"].cpu().detach().numpy())} def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: with torch.no_grad(): @@ -213,27 +233,34 @@ def compute_critic_loss( gamma=self._gamma**batch.intervals, ) - def update_actor( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: - # Q function should be inference mode for stability - self._modules.q_funcs.eval() - + def compute_actor_grad( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: self._modules.actor_optim.zero_grad() - loss = self.compute_actor_loss(batch) + loss["loss"].backward() + return loss - loss.backward() - 
self._modules.actor_optim.step(grad_step) - - return {"actor_loss": float(loss.cpu().detach().numpy())} + def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + # Q function should be inference mode for stability + self._modules.q_funcs.eval() + loss = self._compute_critic_grad(batch) + self._modules.actor_optim.step() + return {"actor_loss": float(loss["loss"].cpu().detach().numpy())} - def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: + def compute_actor_loss( + self, batch: TorchMiniBatch + ) -> Dict[str, torch.Tensor]: with torch.no_grad(): q_t = self._q_func_forwarder.compute_expected_q( batch.observations, reduction="min" ) dist = self._modules.policy(batch.observations) + + loss = {} + if self._modules.temp_optim: + loss.update(self.update_temp(batch, dist)) + log_probs = dist.logits probs = dist.probs if self._modules.log_temp is None: @@ -241,17 +268,17 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: else: temp = get_parameter(self._modules.log_temp).exp() entropy = temp * log_probs - return (probs * (entropy - q_t)).sum(dim=1).mean() + loss["loss"] = (probs * (entropy - q_t)).sum(dim=1).mean() + return loss def update_temp( - self, batch: TorchMiniBatch, grad_step: int - ) -> Dict[str, float]: + self, batch: TorchMiniBatch, dist: Categorical + ) -> Dict[str, torch.Tensor]: assert self._modules.temp_optim assert self._modules.log_temp is not None self._modules.temp_optim.zero_grad() with torch.no_grad(): - dist = self._modules.policy(batch.observations) log_probs = F.log_softmax(dist.logits, dim=1) probs = dist.probs expct_log_probs = (probs * log_probs).sum(dim=1, keepdim=True) @@ -261,31 +288,21 @@ def update_temp( loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() - self._modules.temp_optim.step(grad_step) + self._modules.temp_optim.step() # current temperature value log_temp = get_parameter(self._modules.log_temp) - cur_temp = log_temp.exp().cpu().detach().numpy()[0][0] - return { - "temp_loss": float(loss.cpu().detach().numpy()), - "temp": float(cur_temp), - } + return {"temp_loss": loss, "temp": log_temp.exp()[0][0]} def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - - # lagrangian parameter update for SAC temeprature - if self._modules.temp_optim: - metrics.update(self.update_temp(batch, grad_step)) - metrics.update(self.update_critic(batch, grad_step)) - metrics.update(self.update_actor(batch, grad_step)) - + metrics.update(self.update_critic(batch)) + metrics.update(self.update_actor(batch)) if grad_step % self._target_update_interval == 0: self.update_target() - return metrics def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index ef034fe5..9d9b08fc 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -27,7 +27,7 @@ def __init__( target_smoothing_sigma: float, target_smoothing_clip: float, update_actor_interval: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -38,7 +38,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile=compile, + compile_graph=compile_graph, device=device, ) self._target_smoothing_sigma = target_smoothing_sigma diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 7163e9fb..1d844b22 100644 --- 
a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -33,7 +33,7 @@ def __init__( target_smoothing_clip: float, alpha: float, update_actor_interval: int, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -47,7 +47,7 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, - compile=compile, + compile_graph=compile_graph, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 55313dfd..faba155e 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -1,7 +1,5 @@ import dataclasses -import torch - from ...base import DeviceArg, register_learnable from ...constants import ActionSpace, PositionEncodingType from ...models import EncoderFactory, make_encoder_field @@ -59,7 +57,7 @@ class DecisionTransformerConfig(TransformerConfig): activation_type (str): Type of activation function. position_encoding_type (d3rlpy.PositionEncodingType): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 64 @@ -73,7 +71,7 @@ class DecisionTransformerConfig(TransformerConfig): embed_dropout: float = 0.1 activation_type: str = "relu" position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -121,7 +119,7 @@ def inner_create_impl( action_size=action_size, modules=modules, device=self._device, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, ) def get_action_type(self) -> ActionSpace: @@ -163,7 +161,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_tokens (int): Number of tokens to warmup learning rate scheduler. final_tokens (int): Final number of tokens for learning rate scheduler. - compile (bool): Flag to enable JIT compilation and CUDAGraph. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. 
""" batch_size: int = 128 @@ -180,7 +178,7 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): position_encoding_type: PositionEncodingType = PositionEncodingType.GLOBAL warmup_tokens: int = 10240 final_tokens: int = 30000000 - compile: bool = False + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -233,7 +231,7 @@ def inner_create_impl( warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, - compile=self._config.compile and "cuda" in self._device, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 0d6809da..8fb2a7dd 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -42,13 +42,13 @@ def __init__( observation_shape: Shape, action_size: int, modules: Modules, - compile: bool, + compile_graph: bool, device: str, ): super().__init__(observation_shape, action_size, modules, device) self._compute_grad = ( CudaGraphWrapper(self.compute_grad) # type: ignore - if compile + if compile_graph else self.compute_grad ) @@ -109,7 +109,7 @@ def __init__( warmup_tokens: int, final_tokens: int, initial_learning_rate: float, - compile: bool, + compile_graph: bool, device: str, ): super().__init__( @@ -123,7 +123,7 @@ def __init__( self._initial_learning_rate = initial_learning_rate self._compute_grad = ( CudaGraphWrapper(self.compute_grad) # type: ignore - if compile + if compile_graph else self.compute_grad ) # TODO: Include stateful information in checkpoint. diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py index 1719421c..ff1863d1 100644 --- a/d3rlpy/optimizers/optimizers.py +++ b/d3rlpy/optimizers/optimizers.py @@ -252,8 +252,6 @@ def create_optimizer( eps=self.eps, weight_decay=self.weight_decay, amsgrad=self.amsgrad, - capturable=False, - differentiable=False, ) @staticmethod diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 07357a7b..f0f1a9a2 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -139,7 +139,8 @@ def copy_recursively(src: _T, dst: _T) -> None: if isinstance(src, torch.Tensor) and isinstance(dst, torch.Tensor): dst.copy_(src) elif isinstance(src, (list, tuple)) and isinstance(dst, (list, tuple)): - [d.copy_(s) for s, d in zip(src, dst)] + for s, d in zip(src, dst): + d.copy_(s) else: raise ValueError( f"invalid inpu types: src={type(src)}, dst={type(dst)}" @@ -426,12 +427,12 @@ def unfreeze(self) -> None: def set_eval(self) -> None: for v in asdict_without_copy(self).values(): - if isinstance(v, nn.Module): + if isinstance(v, nn.Module) and v.training: v.eval() def set_train(self) -> None: for v in asdict_without_copy(self).values(): - if isinstance(v, nn.Module): + if isinstance(v, nn.Module) and not v.training: v.train() def reset_optimizer_states(self) -> None: @@ -499,32 +500,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return a * F.gelu(b) -BatchT = TypeVar( - "BatchT", +BatchT_contra = TypeVar( + "BatchT_contra", bound=Union[TorchMiniBatch, TorchTrajectoryMiniBatch], contravariant=True, ) -RetT = TypeVar("RetT", covariant=True) +RetT_co = TypeVar("RetT_co", covariant=True) -class CudaGraphFunc(Generic[BatchT, RetT], Protocol): - def __call__(self, batch: BatchT) -> RetT: ... 
+class CudaGraphFunc(Generic[BatchT_contra, RetT_co], Protocol): + def __call__(self, batch: BatchT_contra) -> RetT_co: ... -class CudaGraphWrapper(Generic[BatchT, RetT]): - _func: CudaGraphFunc[BatchT, RetT] +class CudaGraphWrapper(Generic[BatchT_contra, RetT_co]): + _func: CudaGraphFunc[BatchT_contra, RetT_co] _input: TorchTrajectoryMiniBatch _graph: Optional[CUDAGraph] - _inpt: Optional[BatchT] - _out: Optional[RetT] + _inpt: Optional[BatchT_contra] + _out: Optional[RetT_co] def __init__( self, - func: CudaGraphFunc[BatchT, RetT], + func: CudaGraphFunc[BatchT_contra, RetT_co], warmup_steps: int = 3, - compile: bool = True, + compile_func: bool = True, ): - self._func = torch.compile(func) if compile else func + self._func = torch.compile(func) if compile_func else func self._step = 0 self._graph = None self._inpt = None @@ -532,7 +533,7 @@ def __init__( self._warmup_steps = warmup_steps self._warmup_stream = torch.cuda.Stream() - def __call__(self, batch: BatchT) -> RetT: + def __call__(self, batch: BatchT_contra) -> RetT_co: if self._step < self._warmup_steps: # warmup self._warmup_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._warmup_stream): diff --git a/reproductions/finetuning/cal_ql_finetune.py b/reproductions/finetuning/cal_ql_finetune.py index d937c600..35a492d2 100644 --- a/reproductions/finetuning/cal_ql_finetune.py +++ b/reproductions/finetuning/cal_ql_finetune.py @@ -51,7 +51,6 @@ def main() -> None: alpha_threshold=0.8, reward_scaler=reward_scaler, max_q_backup=True, - compile=True, ).create(device=args.gpu) # pretraining diff --git a/reproductions/offline/cql.py b/reproductions/offline/cql.py index 29562508..3fe8c181 100644 --- a/reproductions/offline/cql.py +++ b/reproductions/offline/cql.py @@ -20,7 +20,7 @@ def main() -> None: encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256]) if "medium-v0" in args.dataset: - conservative_weight = 5.0 + conservative_weight = 10.0 else: conservative_weight = 5.0 @@ -35,8 +35,7 @@ def main() -> None: batch_size=256, n_action_samples=10, alpha_threshold=10, - conservative_weight=5.0, - compile=True, + conservative_weight=conservative_weight, ).create(device=args.gpu) cql.fit( diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py index 4cfb4a97..f3849f86 100644 --- a/reproductions/offline/decision_transformer.py +++ b/reproductions/offline/decision_transformer.py @@ -46,7 +46,6 @@ def main() -> None: num_heads=1, num_layers=3, max_timestep=1000, - compile=False, ).create(device=args.gpu) dt.fit( From 54443e450c5fced21b7637420da1e56efab7677c Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 2 Nov 2024 15:09:10 +0900 Subject: [PATCH 03/15] Fix TD3 --- d3rlpy/algos/qlearning/td3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index ddad6eb0..d193c534 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -167,7 +167,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, - compile_graph=self._config.compile_graph, + compile_graph=self._config.compile_graph and "cuda" in self._device, device=self._device, ) From 303ed528a310265e9224e12e4f3a2595f60761a8 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 2 Nov 2024 16:45:53 +0900 Subject: [PATCH 04/15] Fix DiscreteSAC --- 
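Editor's note (illustration only, not part of the commit message or the diff):
with the series applied up to this point, CUDAGraph and torch.compile execution
is opt-in per algorithm through the new compile_graph config flag and only takes
effect on CUDA devices, since the configs pass compile_graph and "cuda" in device
down to the impls. A minimal usage sketch follows; the CartPole dataset helper,
step count, and device string are arbitrary choices for illustration:

    import d3rlpy

    dataset, env = d3rlpy.datasets.get_cartpole()
    sac = d3rlpy.algos.DiscreteSACConfig(
        compile_graph=True,  # ignored on CPU-only setups
    ).create(device="cuda:0")
    sac.fit(dataset, n_steps=10000)
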
d3rlpy/algos/qlearning/torch/sac_impl.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 57ab0b39..c4283c05 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -244,7 +244,7 @@ def compute_actor_grad( def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() - loss = self._compute_critic_grad(batch) + loss = self._compute_actor_grad(batch) self._modules.actor_optim.step() return {"actor_loss": float(loss["loss"].cpu().detach().numpy())} @@ -259,7 +259,7 @@ def compute_actor_loss( loss = {} if self._modules.temp_optim: - loss.update(self.update_temp(batch, dist)) + loss.update(self.update_temp(dist)) log_probs = dist.logits probs = dist.probs @@ -271,9 +271,7 @@ def compute_actor_loss( loss["loss"] = (probs * (entropy - q_t)).sum(dim=1).mean() return loss - def update_temp( - self, batch: TorchMiniBatch, dist: Categorical - ) -> Dict[str, torch.Tensor]: + def update_temp(self, dist: Categorical) -> Dict[str, torch.Tensor]: assert self._modules.temp_optim assert self._modules.log_temp is not None self._modules.temp_optim.zero_grad() From 911d542c8f63ac6ea516a69b0e77c384ff43957b Mon Sep 17 00:00:00 2001 From: takuseno Date: Sat, 2 Nov 2024 16:46:16 +0900 Subject: [PATCH 05/15] Update torch dependency --- README.md | 2 +- requirements.txt | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6557916a..7b2fa9a9 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ d3rlpy supports Linux, macOS and Windows. ### Dependencies Installing d3rlpy package will install or upgrade the following packages to satisfy requirements: -- torch>=2.0.0 +- torch>=2.5.0 - tqdm>=4.66.3 - gym>=0.26.0 - gymnasium>=1.0.0 diff --git a/requirements.txt b/requirements.txt index e2514a46..604cca3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -torch==2.0.1 +torch==2.5.0 tqdm>=4.66.1 h5py==2.10.0 gym==0.26.2 diff --git a/setup.py b/setup.py index 1b6c90ce..b6be1cb6 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ "Operating System :: MacOS :: MacOS X", ], install_requires=[ - "torch>=2.0.0", + "torch>=2.5.0", "tqdm>=4.66.3", "h5py", "gym>=0.26.0", From ebd4756cbd27eed95ac2e7d18d923e101b8504d0 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 10:03:05 +0900 Subject: [PATCH 06/15] Fix lint error --- d3rlpy/algos/qlearning/torch/bcq_impl.py | 2 +- d3rlpy/algos/qlearning/torch/bear_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/dqn_impl.py | 2 +- d3rlpy/algos/qlearning/torch/plas_impl.py | 2 +- d3rlpy/algos/qlearning/torch/sac_impl.py | 4 ++-- d3rlpy/algos/transformer/torch/decision_transformer_impl.py | 4 ++-- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index b3033508..bed815d6 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -91,7 +91,7 @@ def __init__( self._beta = beta self._rl_start_step = rl_start_step self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + CudaGraphWrapper(self.compute_imitator_grad) if compile_graph else self.compute_imitator_grad ) diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py 
b/d3rlpy/algos/qlearning/torch/bear_impl.py index 3358d487..7aea6217 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -116,12 +116,12 @@ def __init__( self._vae_kl_weight = vae_kl_weight self._warmup_steps = warmup_steps self._compute_warmup_actor_grad = ( - CudaGraphWrapper(self.compute_warmup_actor_grad) # type: ignore + CudaGraphWrapper(self.compute_warmup_actor_grad) if compile_graph else self.compute_warmup_actor_grad ) self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + CudaGraphWrapper(self.compute_imitator_grad) if compile_graph else self.compute_imitator_grad ) diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 3d88dcfd..7a427147 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -87,12 +87,12 @@ def __init__( self._q_func_forwarder = q_func_forwarder self._targ_q_func_forwarder = targ_q_func_forwarder self._compute_critic_grad = ( - CudaGraphWrapper(self.compute_critic_grad) # type: ignore + CudaGraphWrapper(self.compute_critic_grad) if compile_graph else self.compute_critic_grad ) self._compute_actor_grad = ( - CudaGraphWrapper(self.compute_actor_grad) # type: ignore + CudaGraphWrapper(self.compute_actor_grad) if compile_graph else self.compute_actor_grad ) diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index 0ecfaee9..e7835ff6 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -64,7 +64,7 @@ def __init__( self._targ_q_func_forwarder = targ_q_func_forwarder self._target_update_interval = target_update_interval self._compute_grad = ( - CudaGraphWrapper(self.compute_grad) # type: ignore + CudaGraphWrapper(self.compute_grad) if compile_graph else self.compute_grad ) diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 71620cdf..1d60010e 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -71,7 +71,7 @@ def __init__( self._beta = beta self._warmup_steps = warmup_steps self._compute_imitator_grad = ( - CudaGraphWrapper(self.compute_imitator_grad) # type: ignore + CudaGraphWrapper(self.compute_imitator_grad) if compile_graph else self.compute_imitator_grad ) diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index c4283c05..b9521e63 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -174,12 +174,12 @@ def __init__( self._targ_q_func_forwarder = targ_q_func_forwarder self._target_update_interval = target_update_interval self._compute_critic_grad = ( - CudaGraphWrapper(self.compute_critic_grad) # type: ignore + CudaGraphWrapper(self.compute_critic_grad) if compile_graph else self.compute_critic_grad ) self._compute_actor_grad = ( - CudaGraphWrapper(self.compute_actor_grad) # type: ignore + CudaGraphWrapper(self.compute_actor_grad) if compile_graph else self.compute_actor_grad ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 8fb2a7dd..2b03c0b2 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -47,7 +47,7 @@ def __init__( ): super().__init__(observation_shape, action_size, modules, device) 
self._compute_grad = ( - CudaGraphWrapper(self.compute_grad) # type: ignore + CudaGraphWrapper(self.compute_grad) if compile_graph else self.compute_grad ) @@ -122,7 +122,7 @@ def __init__( self._final_tokens = final_tokens self._initial_learning_rate = initial_learning_rate self._compute_grad = ( - CudaGraphWrapper(self.compute_grad) # type: ignore + CudaGraphWrapper(self.compute_grad) if compile_graph else self.compute_grad ) From bc3d1a197bbdc15d04cb374e1d0196f7c3cfb597 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 11:03:28 +0900 Subject: [PATCH 07/15] Add compiled flag to OptimizerWrapper --- d3rlpy/algos/qlearning/awac.py | 12 +++++-- d3rlpy/algos/qlearning/bc.py | 8 +++-- d3rlpy/algos/qlearning/bcq.py | 21 +++++++++--- d3rlpy/algos/qlearning/bear.py | 21 +++++++++--- d3rlpy/algos/qlearning/cal_ql.py | 19 ++++++++--- d3rlpy/algos/qlearning/cql.py | 27 ++++++++++++---- d3rlpy/algos/qlearning/crr.py | 12 +++++-- d3rlpy/algos/qlearning/ddpg.py | 12 +++++-- d3rlpy/algos/qlearning/dqn.py | 16 +++++++--- d3rlpy/algos/qlearning/iql.py | 12 +++++-- d3rlpy/algos/qlearning/nfq.py | 6 +++- d3rlpy/algos/qlearning/plas.py | 26 +++++++++++---- d3rlpy/algos/qlearning/rebrac.py | 12 +++++-- d3rlpy/algos/qlearning/sac.py | 32 ++++++++++++++----- d3rlpy/algos/qlearning/td3.py | 12 +++++-- d3rlpy/algos/qlearning/td3_plus_bc.py | 12 +++++-- .../algos/transformer/decision_transformer.py | 16 +++++++--- d3rlpy/ope/fqe.py | 8 +++-- d3rlpy/optimizers/optimizers.py | 16 ++++++++-- tests/models/test_optimizers.py | 10 +++--- tests/test_torch_utility.py | 8 ++--- 21 files changed, 236 insertions(+), 82 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index 03e93705..e82665a7 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -102,6 +102,8 @@ class AWAC(QLearningAlgoBase[AWACImpl, AWACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_normal_policy( observation_shape, action_size, @@ -132,10 +134,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) dummy_log_temp = Parameter(torch.zeros(1, 1)) @@ -160,7 +166,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index 0cb5e042..adb886ee 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -93,7 +93,9 @@ def inner_create_impl( raise ValueError(f"invalid policy_type: {self._config.policy_type}") optim = self._config.optim_factory.create( - imitator.named_modules(), lr=self._config.learning_rate + imitator.named_modules(), + lr=self._config.learning_rate, + compiled=False, ) modules = BCModules(optim=optim, imitator=imitator) @@ -168,7 +170,9 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - imitator.named_modules(), lr=self._config.learning_rate + 
imitator.named_modules(), + lr=self._config.learning_rate, + compiled=False, ) modules = DiscreteBCModules(optim=optim, imitator=imitator) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 810501fd..aa71e9bb 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -176,6 +176,8 @@ class BCQ(QLearningAlgoBase[BCQImpl, BCQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_residual_policy( observation_shape, action_size, @@ -230,15 +232,20 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) vae_optim = self._config.imitator_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, + compiled=compiled, ) modules = BCQModules( @@ -266,7 +273,7 @@ def inner_create_impl( action_flexibility=self._config.action_flexibility, beta=self._config.beta, rl_start_step=self._config.rl_start_step, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) @@ -364,6 +371,8 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -407,7 +416,9 @@ def inner_create_impl( q_func_params = list(q_funcs.named_modules()) imitator_params = list(imitator.named_modules()) optim = self._config.optim_factory.create( - q_func_params + imitator_params, lr=self._config.learning_rate + q_func_params + imitator_params, + lr=self._config.learning_rate, + compiled=compiled, ) modules = DiscreteBCQModules( @@ -427,7 +438,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 034200b7..64c3f2c7 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -162,6 +162,8 @@ class BEAR(QLearningAlgoBase[BEARImpl, BEARConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_normal_policy( observation_shape, action_size, @@ -219,21 +221,30 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) vae_optim = self._config.imitator_optim_factory.create( 
list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, + compiled=compiled, ) temp_optim = self._config.temp_optim_factory.create( - log_temp.named_modules(), lr=self._config.temp_learning_rate + log_temp.named_modules(), + lr=self._config.temp_learning_rate, + compiled=compiled, ) alpha_optim = self._config.alpha_optim_factory.create( - log_alpha.named_modules(), lr=self._config.actor_learning_rate + log_alpha.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) modules = BEARModules( @@ -268,7 +279,7 @@ def inner_create_impl( mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, warmup_steps=self._config.warmup_steps, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index a6fd2a40..67ef79b1 100644 --- a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -89,6 +89,7 @@ def inner_create_impl( assert not ( self._config.soft_q_backup and self._config.max_q_backup ), "soft_q_backup and max_q_backup are mutually exclusive." + compiled = self._config.compile_graph and "cuda" in self._device policy = create_normal_policy( observation_shape, @@ -129,20 +130,28 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( - log_temp.named_modules(), lr=self._config.temp_learning_rate + log_temp.named_modules(), + lr=self._config.temp_learning_rate, + compiled=compiled, ) else: temp_optim = None if self._config.alpha_learning_rate > 0: alpha_optim = self._config.alpha_optim_factory.create( - log_alpha.named_modules(), lr=self._config.alpha_learning_rate + log_alpha.named_modules(), + lr=self._config.alpha_learning_rate, + compiled=compiled, ) else: alpha_optim = None @@ -172,7 +181,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 526f9c94..5d860346 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -144,6 +144,7 @@ def inner_create_impl( assert not ( self._config.soft_q_backup and self._config.max_q_backup ), "soft_q_backup and max_q_backup are mutually exclusive." 
+ compiled = self._config.compile_graph and "cuda" in self._device policy = create_normal_policy( observation_shape, @@ -184,20 +185,28 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( - log_temp.named_modules(), lr=self._config.temp_learning_rate + log_temp.named_modules(), + lr=self._config.temp_learning_rate, + compiled=compiled, ) else: temp_optim = None if self._config.alpha_learning_rate > 0: alpha_optim = self._config.alpha_optim_factory.create( - log_alpha.named_modules(), lr=self._config.alpha_learning_rate + log_alpha.named_modules(), + lr=self._config.alpha_learning_rate, + compiled=compiled, ) else: alpha_optim = None @@ -227,7 +236,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) @@ -303,6 +312,8 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -323,7 +334,9 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DQNModules( @@ -341,7 +354,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 46b3b981..3b06d62d 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -137,6 +137,8 @@ class CRR(QLearningAlgoBase[CRRImpl, CRRConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_normal_policy( observation_shape, action_size, @@ -171,10 +173,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = CRRModules( @@ -201,7 +207,7 @@ def inner_create_impl( tau=self._config.tau, target_update_type=self._config.target_update_type, target_update_interval=self._config.target_update_interval, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git 
a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 3144d160..e63ee92c 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -98,6 +98,8 @@ class DDPG(QLearningAlgoBase[DDPGImpl, DDPGConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, action_size, @@ -132,10 +134,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = DDPGModules( @@ -155,7 +161,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=self._config.compile_graph, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index a4a7e8f7..7175721b 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -71,6 +71,8 @@ class DQN(QLearningAlgoBase[DQNImpl, DQNConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, forwarder = create_discrete_q_function( observation_shape, action_size, @@ -91,7 +93,9 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DQNModules( @@ -108,7 +112,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) @@ -181,6 +185,8 @@ class DoubleDQN(DQN): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, forwarder = create_discrete_q_function( observation_shape, action_size, @@ -201,7 +207,9 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DQNModules( @@ -218,7 +226,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index a95b79f2..324c57a4 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -113,6 +113,8 @@ class IQL(QLearningAlgoBase[IQLImpl, IQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_normal_policy( observation_shape, action_size, @@ -149,12 +151,16 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - 
policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) q_func_params = list(q_funcs.named_modules()) v_func_params = list(value_func.named_modules()) critic_optim = self._config.critic_optim_factory.create( - q_func_params + v_func_params, lr=self._config.critic_learning_rate + q_func_params + v_func_params, + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = IQLModules( @@ -177,7 +183,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 9855fa1c..6ff0a88b 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -73,6 +73,8 @@ class NFQ(QLearningAlgoBase[DQNImpl, NFQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -93,7 +95,9 @@ def inner_create_impl( ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DQNModules( diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index e5abc75a..70131038 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -113,6 +113,8 @@ class PLAS(QLearningAlgoBase[PLASImpl, PLASConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, 2 * action_size, @@ -165,15 +167,20 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) vae_optim = self._config.critic_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, + compiled=compiled, ) modules = PLASModules( @@ -199,7 +206,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) @@ -271,6 +278,8 @@ class PLASWithPerturbation(PLAS): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, 2 * action_size, @@ -341,15 +350,20 @@ def inner_create_impl( named_modules = list(policy.named_modules()) named_modules += list(perturbation.named_modules()) actor_optim = self._config.actor_optim_factory.create( - named_modules, lr=self._config.actor_learning_rate + named_modules, + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = 
self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) vae_optim = self._config.critic_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, + compiled=compiled, ) modules = PLASWithPerturbationModules( @@ -377,7 +391,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index cc237910..b44fabee 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -106,6 +106,8 @@ class ReBRAC(QLearningAlgoBase[ReBRACImpl, ReBRACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, action_size, @@ -140,10 +142,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = DDPGModules( @@ -168,7 +174,7 @@ def inner_create_impl( actor_beta=self._config.actor_beta, critic_beta=self._config.critic_beta, update_actor_interval=self._config.update_actor_interval, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 71d8b8e9..a3136c55 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -127,6 +127,8 @@ class SAC(QLearningAlgoBase[SACImpl, SACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_normal_policy( observation_shape, action_size, @@ -160,14 +162,20 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( - log_temp.named_modules(), lr=self._config.temp_learning_rate + log_temp.named_modules(), + lr=self._config.temp_learning_rate, + compiled=compiled, ) else: temp_optim = None @@ -190,7 +198,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) @@ -285,6 +293,8 @@ class DiscreteSAC(QLearningAlgoBase[DiscreteSACImpl, DiscreteSACConfig]): def inner_create_impl( self, observation_shape: Shape, 
action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -321,15 +331,21 @@ def inner_create_impl( log_temp = None critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) if self._config.temp_learning_rate > 0: assert log_temp is not None temp_optim = self._config.temp_optim_factory.create( - log_temp.named_modules(), lr=self._config.temp_learning_rate + log_temp.named_modules(), + lr=self._config.temp_learning_rate, + compiled=compiled, ) else: temp_optim = None @@ -352,7 +368,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index d193c534..88fccc94 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -107,6 +107,8 @@ class TD3(QLearningAlgoBase[TD3Impl, TD3Config]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, action_size, @@ -141,10 +143,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = DDPGModules( @@ -167,7 +173,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index bdbda4b6..78a5bede 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -99,6 +99,8 @@ class TD3PlusBC(QLearningAlgoBase[TD3PlusBCImpl, TD3PlusBCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + policy = create_deterministic_policy( observation_shape, action_size, @@ -133,10 +135,14 @@ def inner_create_impl( ) actor_optim = self._config.actor_optim_factory.create( - policy.named_modules(), lr=self._config.actor_learning_rate + policy.named_modules(), + lr=self._config.actor_learning_rate, + compiled=compiled, ) critic_optim = self._config.critic_optim_factory.create( - q_funcs.named_modules(), lr=self._config.critic_learning_rate + q_funcs.named_modules(), + lr=self._config.critic_learning_rate, + compiled=compiled, ) modules = DDPGModules( 
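As a minimal usage sketch (not part of this patch series), the following illustrates the `compiled` argument that this commit threads through `OptimizerFactory.create()` and into `OptimizerWrapper`, as seen in the d3rlpy/optimizers/optimizers.py hunk later in this commit. The import path and the config field name are assumptions inferred from the file layout and the diffs in this series; the flag is only meant to be True when graph compilation is enabled on a CUDA device.

# Sketch only: mirrors how this patch wires `compiled` through the optimizer
# factory. The AdamFactory import path is assumed from the package layout.
import torch
from d3rlpy.optimizers import AdamFactory

model = torch.nn.Linear(4, 2)
factory = AdamFactory()
# With compiled=True, OptimizerWrapper.zero_grad() calls
# optim.zero_grad(set_to_none=True), matching the CudaGraph/torch.compile path.
optim = factory.create(model.named_modules(), lr=3e-4, compiled=True)
optim.zero_grad()

# At the algorithm level the same flag comes from the config (field name taken
# from this series; it only takes effect on CUDA devices), e.g.:
#   config = d3rlpy.algos.TD3PlusBCConfig(compile_graph=True)
#   algo = config.create(device="cuda:0")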
@@ -160,7 +166,7 @@ def inner_create_impl( target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index faba155e..015cae08 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -89,6 +89,8 @@ class DecisionTransformer( def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + transformer = create_continuous_decision_transformer( observation_shape=observation_shape, action_size=action_size, @@ -106,7 +108,9 @@ def inner_create_impl( enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( - transformer.named_modules(), lr=self._config.learning_rate + transformer.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DecisionTransformerModules( @@ -119,7 +123,7 @@ def inner_create_impl( action_size=action_size, modules=modules, device=self._device, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, ) def get_action_type(self) -> ActionSpace: @@ -198,6 +202,8 @@ class DiscreteDecisionTransformer( def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: + compiled = self._config.compile_graph and "cuda" in self._device + transformer = create_discrete_decision_transformer( observation_shape=observation_shape, action_size=action_size, @@ -216,7 +222,9 @@ def inner_create_impl( enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( - transformer.named_modules(), lr=self._config.learning_rate + transformer.named_modules(), + lr=self._config.learning_rate, + compiled=compiled, ) modules = DiscreteDecisionTransformerModules( @@ -231,7 +239,7 @@ def inner_create_impl( warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=compiled, device=self._device, ) diff --git a/d3rlpy/ope/fqe.py b/d3rlpy/ope/fqe.py index ad2333d0..52b7b075 100644 --- a/d3rlpy/ope/fqe.py +++ b/d3rlpy/ope/fqe.py @@ -167,7 +167,9 @@ def inner_create_impl( enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=False, ) modules = FQEBaseModules( @@ -245,7 +247,9 @@ def inner_create_impl( enable_ddp=self._enable_ddp, ) optim = self._config.optim_factory.create( - q_funcs.named_modules(), lr=self._config.learning_rate + q_funcs.named_modules(), + lr=self._config.learning_rate, + compiled=False, ) modules = FQEBaseModules( q_funcs=q_funcs, diff --git a/d3rlpy/optimizers/optimizers.py b/d3rlpy/optimizers/optimizers.py index ff1863d1..ced16e9b 100644 --- a/d3rlpy/optimizers/optimizers.py +++ b/d3rlpy/optimizers/optimizers.py @@ -42,11 +42,13 @@ class OptimizerWrapper: Args: params: List of torch parameters. optim: PyTorch optimizer. + compiled: Flag to be True if CudaGraph and torch.compile are applied. clip_grad_norm: Maximum norm value of gradients to clip. 
""" _params: Sequence[nn.Parameter] _optim: Optimizer + _compiled: bool _clip_grad_norm: Optional[float] _lr_scheduler: Optional[LRScheduler] @@ -54,16 +56,18 @@ def __init__( self, params: Sequence[nn.Parameter], optim: Optimizer, + compiled: bool, clip_grad_norm: Optional[float] = None, lr_scheduler: Optional[LRScheduler] = None, ): self._params = params self._optim = optim + self._compiled = compiled self._clip_grad_norm = clip_grad_norm self._lr_scheduler = lr_scheduler - def zero_grad(self, set_to_none: bool = False) -> None: - self._optim.zero_grad(set_to_none=set_to_none) + def zero_grad(self) -> None: + self._optim.zero_grad(set_to_none=self._compiled) def step(self) -> None: """Updates parameters. @@ -103,13 +107,18 @@ class OptimizerFactory(DynamicConfig): ) def create( - self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + self, + named_modules: Iterable[Tuple[str, nn.Module]], + lr: float, + compiled: bool, ) -> OptimizerWrapper: """Returns an optimizer object. Args: named_modules (list): List of tuples of module names and modules. lr (float): Learning rate. + compiled (bool): Flag to be True if CudaGraph and torch.compile are + applied. Returns: OptimizerWrapper object. @@ -120,6 +129,7 @@ def create( return OptimizerWrapper( params=params, optim=optim, + compiled=compiled, clip_grad_norm=self.clip_grad_norm, lr_scheduler=( self.lr_scheduler_factory.create(optim) diff --git a/tests/models/test_optimizers.py b/tests/models/test_optimizers.py index 1f98d03a..0bc57767 100644 --- a/tests/models/test_optimizers.py +++ b/tests/models/test_optimizers.py @@ -17,7 +17,7 @@ def test_sgd_factory(lr: float, module: torch.nn.Module) -> None: factory = SGDFactory() - optim = factory.create(module.named_modules(), lr) + optim = factory.create(module.named_modules(), lr, False) assert isinstance(optim.optim, SGD) assert optim.optim.defaults["lr"] == lr @@ -31,7 +31,7 @@ def test_sgd_factory(lr: float, module: torch.nn.Module) -> None: def test_adam_factory(lr: float, module: torch.nn.Module) -> None: factory = AdamFactory() - optim = factory.create(module.named_modules(), lr) + optim = factory.create(module.named_modules(), lr, False) assert isinstance(optim.optim, Adam) assert optim.optim.defaults["lr"] == lr @@ -45,7 +45,7 @@ def test_adam_factory(lr: float, module: torch.nn.Module) -> None: def test_adam_w_factory(lr: float, module: torch.nn.Module) -> None: factory = AdamWFactory() - optim = factory.create(module.named_modules(), lr) + optim = factory.create(module.named_modules(), lr, False) assert isinstance(optim.optim, AdamW) assert optim.optim.defaults["lr"] == lr @@ -59,7 +59,7 @@ def test_adam_w_factory(lr: float, module: torch.nn.Module) -> None: def test_rmsprop_factory(lr: float, module: torch.nn.Module) -> None: factory = RMSpropFactory() - optim = factory.create(module.named_modules(), lr) + optim = factory.create(module.named_modules(), lr, False) assert isinstance(optim.optim, RMSprop) assert optim.optim.defaults["lr"] == lr @@ -81,7 +81,7 @@ def __init__(self) -> None: module = M() - optim = factory.create(module.named_modules(), lr) + optim = factory.create(module.named_modules(), lr, False) assert isinstance(optim.optim, AdamW) assert optim.optim.defaults["lr"] == lr diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index 9f96484e..9c5d0b70 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -116,7 +116,7 @@ def __init__(self) -> None: self.fc1 = torch.nn.Linear(100, 100) self.fc2 = torch.nn.Linear(100, 100) 
params = list(self.fc1.parameters()) - self.optim = OptimizerWrapper(params, torch.optim.Adam(params)) + self.optim = OptimizerWrapper(params, torch.optim.Adam(params), False) self.modules = DummyModules(self.fc1, self.optim) self.device = "cpu:0" @@ -191,7 +191,7 @@ class DummyModules(Modules): def test_modules() -> None: fc = torch.nn.Linear(100, 200) params = list(fc.parameters()) - optim = OptimizerWrapper(params, torch.optim.Adam(params)) + optim = OptimizerWrapper(params, torch.optim.Adam(params), False) modules = DummyModules(fc, optim) # check checkpointer @@ -402,7 +402,7 @@ def test_checkpointer() -> None: fc1 = torch.nn.Linear(100, 100) fc2 = torch.nn.Linear(100, 100) params = list(fc1.parameters()) - optim = OptimizerWrapper(params, torch.optim.Adam(params)) + optim = OptimizerWrapper(params, torch.optim.Adam(params), False) checkpointer = Checkpointer( modules={"fc1": fc1, "fc2": fc2, "optim": optim}, device="cpu:0" ) @@ -424,7 +424,7 @@ def test_checkpointer() -> None: fc1_2 = torch.nn.Linear(100, 100) fc2_2 = torch.nn.Linear(100, 100) params_2 = list(fc1_2.parameters()) - optim_2 = OptimizerWrapper(params_2, torch.optim.Adam(params_2)) + optim_2 = OptimizerWrapper(params_2, torch.optim.Adam(params_2), False) checkpointer = Checkpointer( modules={"fc1": fc1_2, "fc2": fc2_2, "optim": optim_2}, device="cpu:0" ) From 78208738cb5901eb960d65596f93014bafc6d6a8 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 14:14:14 +0900 Subject: [PATCH 08/15] Workaround DiscreteSAC test --- tests/algos/qlearning/test_sac.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algos/qlearning/test_sac.py b/tests/algos/qlearning/test_sac.py index 072e85af..bea8ee74 100644 --- a/tests/algos/qlearning/test_sac.py +++ b/tests/algos/qlearning/test_sac.py @@ -43,7 +43,7 @@ def test_sac( @pytest.mark.parametrize( - "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))] + "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] ) @pytest.mark.parametrize( "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] From 0f370659c4d3f492a7a5d4932b0e6e56998140c7 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 15:28:42 +0900 Subject: [PATCH 09/15] Add compiled property --- d3rlpy/algos/qlearning/awac.py | 9 +++----- d3rlpy/algos/qlearning/bcq.py | 18 +++++---------- d3rlpy/algos/qlearning/bear.py | 15 +++++-------- d3rlpy/algos/qlearning/cal_ql.py | 12 +++++----- d3rlpy/algos/qlearning/cql.py | 20 ++++++----------- d3rlpy/algos/qlearning/crr.py | 9 +++----- d3rlpy/algos/qlearning/ddpg.py | 9 +++----- d3rlpy/algos/qlearning/dqn.py | 14 ++++-------- d3rlpy/algos/qlearning/iql.py | 9 +++----- d3rlpy/algos/qlearning/nfq.py | 7 ++---- d3rlpy/algos/qlearning/plas.py | 21 +++++++----------- d3rlpy/algos/qlearning/rebrac.py | 9 +++----- d3rlpy/algos/qlearning/sac.py | 22 +++++++------------ d3rlpy/algos/qlearning/td3.py | 9 +++----- d3rlpy/algos/qlearning/td3_plus_bc.py | 9 +++----- .../algos/transformer/decision_transformer.py | 12 ++++------ d3rlpy/base.py | 13 +++++++++++ 17 files changed, 84 insertions(+), 133 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index e82665a7..d0662795 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -86,7 +86,6 @@ class AWACConfig(LearnableConfig): lam: float = 1.0 n_action_samples: int = 1 n_critics: int = 2 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -102,8 +101,6 @@ class 
AWAC(QLearningAlgoBase[AWACImpl, AWACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -136,12 +133,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) dummy_log_temp = Parameter(torch.zeros(1, 1)) @@ -166,7 +163,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index aa71e9bb..65a599f2 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -160,7 +160,6 @@ class BCQConfig(LearnableConfig): action_flexibility: float = 0.05 rl_start_step: int = 0 beta: float = 0.5 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -176,8 +175,6 @@ class BCQ(QLearningAlgoBase[BCQImpl, BCQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_residual_policy( observation_shape, action_size, @@ -234,18 +231,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) vae_optim = self._config.imitator_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = BCQModules( @@ -273,7 +270,7 @@ def inner_create_impl( action_flexibility=self._config.action_flexibility, beta=self._config.beta, rl_start_step=self._config.rl_start_step, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) @@ -355,7 +352,6 @@ class DiscreteBCQConfig(LearnableConfig): beta: float = 0.5 target_update_interval: int = 8000 share_encoder: bool = True - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -371,8 +367,6 @@ class DiscreteBCQ(QLearningAlgoBase[DiscreteBCQImpl, DiscreteBCQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -418,7 +412,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_func_params + imitator_params, lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DiscreteBCQModules( @@ -438,7 +432,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py 
b/d3rlpy/algos/qlearning/bear.py index 64c3f2c7..5aa00aad 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -146,7 +146,6 @@ class BEARConfig(LearnableConfig): mmd_sigma: float = 20.0 vae_kl_weight: float = 0.5 warmup_steps: int = 40000 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -162,8 +161,6 @@ class BEAR(QLearningAlgoBase[BEARImpl, BEARConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -223,28 +220,28 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) vae_optim = self._config.imitator_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, - compiled=compiled, + compiled=self.compiled, ) temp_optim = self._config.temp_optim_factory.create( log_temp.named_modules(), lr=self._config.temp_learning_rate, - compiled=compiled, + compiled=self.compiled, ) alpha_optim = self._config.alpha_optim_factory.create( log_alpha.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = BEARModules( @@ -279,7 +276,7 @@ def inner_create_impl( mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, warmup_steps=self._config.warmup_steps, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index 67ef79b1..df4782a1 100644 --- a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -89,8 +89,6 @@ def inner_create_impl( assert not ( self._config.soft_q_backup and self._config.max_q_backup ), "soft_q_backup and max_q_backup are mutually exclusive." 
- compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -132,18 +130,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( log_temp.named_modules(), lr=self._config.temp_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: temp_optim = None @@ -151,7 +149,7 @@ def inner_create_impl( alpha_optim = self._config.alpha_optim_factory.create( log_alpha.named_modules(), lr=self._config.alpha_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: alpha_optim = None @@ -181,7 +179,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index 5d860346..f1107752 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -125,7 +125,6 @@ class CQLConfig(LearnableConfig): n_action_samples: int = 10 soft_q_backup: bool = False max_q_backup: bool = False - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -144,8 +143,6 @@ def inner_create_impl( assert not ( self._config.soft_q_backup and self._config.max_q_backup ), "soft_q_backup and max_q_backup are mutually exclusive." 
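The d3rlpy/base.py hunk that defines the new `compiled` property is not included in this excerpt; the sketch below is a hypothetical reconstruction inferred from the diffstat of this commit (d3rlpy/base.py, +13) and from the inline `compiled = ...` checks that the hunks in this commit delete.

# Hypothetical sketch of the property added to the algorithm base class in
# d3rlpy/base.py; it centralizes the check each algorithm computed inline.
class LearnableBase:  # illustrative stand-in, not the real class body
    _config: "LearnableConfig"
    _device: str

    @property
    def compiled(self) -> bool:
        # True only when compilation is requested and the device is CUDA.
        return self._config.compile_graph and "cuda" in self._device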
- compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -187,18 +184,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( log_temp.named_modules(), lr=self._config.temp_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: temp_optim = None @@ -206,7 +203,7 @@ def inner_create_impl( alpha_optim = self._config.alpha_optim_factory.create( log_alpha.named_modules(), lr=self._config.alpha_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: alpha_optim = None @@ -236,7 +233,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) @@ -296,7 +293,6 @@ class DiscreteCQLConfig(LearnableConfig): n_critics: int = 1 target_update_interval: int = 8000 alpha: float = 1.0 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -312,8 +308,6 @@ class DiscreteCQL(QLearningAlgoBase[DiscreteCQLImpl, DiscreteCQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -336,7 +330,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DQNModules( @@ -354,7 +348,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 3b06d62d..26c37558 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -121,7 +121,6 @@ class CRRConfig(LearnableConfig): tau: float = 5e-3 target_update_interval: int = 100 update_actor_interval: int = 1 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -137,8 +136,6 @@ class CRR(QLearningAlgoBase[CRRImpl, CRRConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -175,12 +172,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = CRRModules( @@ -207,7 +204,7 @@ def inner_create_impl( tau=self._config.tau, target_update_type=self._config.target_update_type, target_update_interval=self._config.target_update_interval, - compile_graph=compiled, + 
compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index e63ee92c..15daba0b 100644 --- a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -82,7 +82,6 @@ class DDPGConfig(LearnableConfig): q_func_factory: QFunctionFactory = make_q_func_field() tau: float = 0.005 n_critics: int = 1 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -98,8 +97,6 @@ class DDPG(QLearningAlgoBase[DDPGImpl, DDPGConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, action_size, @@ -136,12 +133,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DDPGModules( @@ -161,7 +158,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 7175721b..67294d5e 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -55,7 +55,6 @@ class DQNConfig(LearnableConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -71,8 +70,6 @@ class DQN(QLearningAlgoBase[DQNImpl, DQNConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, forwarder = create_discrete_q_function( observation_shape, action_size, @@ -95,7 +92,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DQNModules( @@ -112,7 +109,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) @@ -169,7 +166,6 @@ class DoubleDQNConfig(DQNConfig): gamma: float = 0.99 n_critics: int = 1 target_update_interval: int = 8000 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -185,8 +181,6 @@ class DoubleDQN(DQN): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, forwarder = create_discrete_q_function( observation_shape, action_size, @@ -209,7 +203,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DQNModules( @@ -226,7 +220,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py 
b/d3rlpy/algos/qlearning/iql.py index 324c57a4..598f2ece 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -97,7 +97,6 @@ class IQLConfig(LearnableConfig): expectile: float = 0.7 weight_temp: float = 3.0 max_weight: float = 100.0 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -113,8 +112,6 @@ class IQL(QLearningAlgoBase[IQLImpl, IQLConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -153,14 +150,14 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) q_func_params = list(q_funcs.named_modules()) v_func_params = list(value_func.named_modules()) critic_optim = self._config.critic_optim_factory.create( q_func_params + v_func_params, lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = IQLModules( @@ -183,7 +180,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index 6ff0a88b..b6346dfd 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -57,7 +57,6 @@ class NFQConfig(LearnableConfig): batch_size: int = 32 gamma: float = 0.99 n_critics: int = 1 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -73,8 +72,6 @@ class NFQ(QLearningAlgoBase[DQNImpl, NFQConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -97,7 +94,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( q_funcs.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DQNModules( @@ -114,7 +111,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=1, gamma=self._config.gamma, - compile_graph=self._config.compile_graph and "cuda" in self._device, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 70131038..0b254e09 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -97,7 +97,6 @@ class PLASConfig(LearnableConfig): lam: float = 0.75 warmup_steps: int = 500000 beta: float = 0.5 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -113,8 +112,6 @@ class PLAS(QLearningAlgoBase[PLASImpl, PLASConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, 2 * action_size, @@ -169,18 +166,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), 
lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) vae_optim = self._config.critic_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = PLASModules( @@ -206,7 +203,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) @@ -278,8 +275,6 @@ class PLASWithPerturbation(PLAS): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, 2 * action_size, @@ -352,18 +347,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( named_modules, lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) vae_optim = self._config.critic_optim_factory.create( list(vae_encoder.named_modules()) + list(vae_decoder.named_modules()), lr=self._config.imitator_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = PLASWithPerturbationModules( @@ -391,7 +386,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index b44fabee..9ab75854 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -90,7 +90,6 @@ class ReBRACConfig(LearnableConfig): actor_beta: float = 0.001 critic_beta: float = 0.01 update_actor_interval: int = 2 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -106,8 +105,6 @@ class ReBRAC(QLearningAlgoBase[ReBRACImpl, ReBRACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, action_size, @@ -144,12 +141,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DDPGModules( @@ -174,7 +171,7 @@ def inner_create_impl( actor_beta=self._config.actor_beta, critic_beta=self._config.critic_beta, update_actor_interval=self._config.update_actor_interval, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index a3136c55..4b12d42d 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -111,7 +111,6 @@ class SACConfig(LearnableConfig): tau: float = 0.005 n_critics: int = 2 initial_temperature: float = 1.0 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -127,8 +126,6 @@ class SAC(QLearningAlgoBase[SACImpl, SACConfig]): def inner_create_impl( self, observation_shape: 
Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_normal_policy( observation_shape, action_size, @@ -164,18 +161,18 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) if self._config.temp_learning_rate > 0: temp_optim = self._config.temp_optim_factory.create( log_temp.named_modules(), lr=self._config.temp_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: temp_optim = None @@ -198,7 +195,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) @@ -277,7 +274,6 @@ class DiscreteSACConfig(LearnableConfig): n_critics: int = 2 initial_temperature: float = 1.0 target_update_interval: int = 8000 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -293,8 +289,6 @@ class DiscreteSAC(QLearningAlgoBase[DiscreteSACImpl, DiscreteSACConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - q_funcs, q_func_forwarder = create_discrete_q_function( observation_shape, action_size, @@ -333,19 +327,19 @@ def inner_create_impl( critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) if self._config.temp_learning_rate > 0: assert log_temp is not None temp_optim = self._config.temp_optim_factory.create( log_temp.named_modules(), lr=self._config.temp_learning_rate, - compiled=compiled, + compiled=self.compiled, ) else: temp_optim = None @@ -368,7 +362,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index 88fccc94..af44fc31 100644 --- a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -91,7 +91,6 @@ class TD3Config(LearnableConfig): target_smoothing_sigma: float = 0.2 target_smoothing_clip: float = 0.5 update_actor_interval: int = 2 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -107,8 +106,6 @@ class TD3(QLearningAlgoBase[TD3Impl, TD3Config]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, action_size, @@ -145,12 +142,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + 
compiled=self.compiled, ) modules = DDPGModules( @@ -173,7 +170,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 78a5bede..3ed57cf1 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -83,7 +83,6 @@ class TD3PlusBCConfig(LearnableConfig): target_smoothing_clip: float = 0.5 alpha: float = 2.5 update_actor_interval: int = 2 - compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -99,8 +98,6 @@ class TD3PlusBC(QLearningAlgoBase[TD3PlusBCImpl, TD3PlusBCConfig]): def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - policy = create_deterministic_policy( observation_shape, action_size, @@ -137,12 +134,12 @@ def inner_create_impl( actor_optim = self._config.actor_optim_factory.create( policy.named_modules(), lr=self._config.actor_learning_rate, - compiled=compiled, + compiled=self.compiled, ) critic_optim = self._config.critic_optim_factory.create( q_funcs.named_modules(), lr=self._config.critic_learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DDPGModules( @@ -166,7 +163,7 @@ def inner_create_impl( target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, - compile_graph=compiled, + compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 015cae08..870a679a 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -89,8 +89,6 @@ class DecisionTransformer( def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - transformer = create_continuous_decision_transformer( observation_shape=observation_shape, action_size=action_size, @@ -110,7 +108,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( transformer.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DecisionTransformerModules( @@ -123,7 +121,7 @@ def inner_create_impl( action_size=action_size, modules=modules, device=self._device, - compile_graph=compiled, + compile_graph=self.compiled, ) def get_action_type(self) -> ActionSpace: @@ -202,8 +200,6 @@ class DiscreteDecisionTransformer( def inner_create_impl( self, observation_shape: Shape, action_size: int ) -> None: - compiled = self._config.compile_graph and "cuda" in self._device - transformer = create_discrete_decision_transformer( observation_shape=observation_shape, action_size=action_size, @@ -224,7 +220,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( transformer.named_modules(), lr=self._config.learning_rate, - compiled=compiled, + compiled=self.compiled, ) modules = DiscreteDecisionTransformerModules( @@ -239,7 +235,7 @@ def inner_create_impl( warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, - compile_graph=compiled, + 
compile_graph=self.compiled, device=self._device, ) diff --git a/d3rlpy/base.py b/d3rlpy/base.py index 2b526373..cb9541ee 100644 --- a/d3rlpy/base.py +++ b/d3rlpy/base.py @@ -96,6 +96,7 @@ class LearnableConfig(DynamicConfig): ) action_scaler: Optional[ActionScaler] = make_action_scaler_field() reward_scaler: Optional[RewardScaler] = make_reward_scaler_field() + compile_graph: bool = False def create( self, device: DeviceArg = False, enable_ddp: bool = False @@ -356,6 +357,18 @@ def config(self) -> TConfig_co: """ return self._config + @property + def compiled(self) -> bool: + """Compiled flag. + + This represents if computational graph is optimized with CudaGraph and + torch.compile. + + Returns: + bool: True if compiled. + """ + return self._config.compile_graph and "cuda" in self._device + @property def batch_size(self) -> int: """Batch size to train. From 92d4013e70bb9f4684c29856317ee69509dd9920 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 16:36:09 +0900 Subject: [PATCH 10/15] Add tests --- tests/test_torch_utility.py | 56 +++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/test_torch_utility.py b/tests/test_torch_utility.py index 9c5d0b70..d6412682 100644 --- a/tests/test_torch_utility.py +++ b/tests/test_torch_utility.py @@ -18,6 +18,7 @@ TorchMiniBatch, TorchTrajectoryMiniBatch, View, + copy_recursively, eval_api, get_batch_size, get_device, @@ -168,6 +169,19 @@ def test_to_cpu() -> None: pass +def test_copy_recursively() -> None: + x = torch.rand(10) + y = torch.rand(10) + copy_recursively(x, y) + assert torch.all(x == y) + + x_list = [torch.rand(10), torch.rand(20)] + y_list = [torch.rand(10), torch.rand(20)] + copy_recursively(x_list, y_list) + assert torch.all(x_list[0] == y_list[0]) + assert torch.all(x_list[1] == y_list[1]) + + def test_get_device() -> None: x = torch.rand(10) assert get_device(x) == "cpu" @@ -323,6 +337,29 @@ def test_torch_mini_batch( assert np.all(torch_batch.terminals.numpy() == batch.terminals) assert np.all(torch_batch.intervals.numpy() == batch.intervals) + torch_batch2 = TorchMiniBatch( + observations=torch.zeros_like(torch_batch.observations), + actions=torch.zeros_like(torch_batch.actions), + rewards=torch.zeros_like(torch_batch.rewards), + next_observations=torch.zeros_like(torch_batch.next_observations), + next_actions=torch.zeros_like(torch_batch.next_actions), + returns_to_go=torch.zeros_like(torch_batch.returns_to_go), + terminals=torch.zeros_like(torch_batch.terminals), + intervals=torch.zeros_like(torch_batch.intervals), + device=torch_batch.device, + ) + torch_batch2.copy_(torch_batch) + assert torch.all(torch_batch2.observations == torch_batch.observations) + assert torch.all(torch_batch2.actions == torch_batch.actions) + assert torch.all(torch_batch2.rewards == torch_batch.rewards) + assert torch.all( + torch_batch2.next_observations == torch_batch.next_observations + ) + assert torch.all(torch_batch2.next_actions == torch_batch.next_actions) + assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go) + assert torch.all(torch_batch2.terminals == torch_batch.terminals) + assert torch.all(torch_batch2.intervals == torch_batch.intervals) + @pytest.mark.parametrize("batch_size", [32]) @pytest.mark.parametrize("length", [32]) @@ -397,6 +434,25 @@ def test_torch_trajectory_mini_batch( assert np.all(torch_batch.terminals.numpy() == batch.terminals) + torch_batch2 = TorchTrajectoryMiniBatch( + observations=torch.zeros_like(torch_batch.observations), + 
actions=torch.zeros_like(torch_batch.actions), + rewards=torch.zeros_like(torch_batch.rewards), + returns_to_go=torch.zeros_like(torch_batch.returns_to_go), + terminals=torch.zeros_like(torch_batch.terminals), + timesteps=torch.zeros_like(torch_batch.timesteps), + masks=torch.zeros_like(torch_batch.masks), + device=torch_batch.device, + ) + torch_batch2.copy_(torch_batch) + assert torch.all(torch_batch2.observations == torch_batch.observations) + assert torch.all(torch_batch2.actions == torch_batch.actions) + assert torch.all(torch_batch2.rewards == torch_batch.rewards) + assert torch.all(torch_batch2.returns_to_go == torch_batch.returns_to_go) + assert torch.all(torch_batch2.terminals == torch_batch.terminals) + assert torch.all(torch_batch2.timesteps == torch_batch.timesteps) + assert torch.all(torch_batch2.masks == torch_batch.masks) + def test_checkpointer() -> None: fc1 = torch.nn.Linear(100, 100) From 2c74ce93387cc025ddb2b0338c77cacf8217b8ef Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 16:38:27 +0900 Subject: [PATCH 11/15] Update python version in readthedocs --- .readthedocs.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5cc47712..00956b64 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -2,7 +2,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" sphinx: builder: html configuration: docs/conf.py From 03888bec7064692b2cd3d3b1ffe277225f9bc096 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 16:43:43 +0900 Subject: [PATCH 12/15] Remove unnecessary change --- d3rlpy/types.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/d3rlpy/types.py b/d3rlpy/types.py index e1e50342..2d532c34 100644 --- a/d3rlpy/types.py +++ b/d3rlpy/types.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence, Type, TypeVar, Union +from typing import Any, Sequence, Union import gym import gymnasium @@ -20,7 +20,6 @@ "TorchObservation", "GymEnv", "OptimizerWrapperProto", - "assert_cast", ] @@ -43,11 +42,3 @@ class OptimizerWrapperProto(Protocol): @property def optim(self) -> Optimizer: raise NotImplementedError - - -T = TypeVar("T") - - -def assert_cast(obj_type: Type[T], obj: Any) -> T: - assert isinstance(obj, obj_type) - return obj From b41be5a97f46f203274c214c4bcc1f94ed7a40c3 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 18:30:47 +0900 Subject: [PATCH 13/15] Rename compile_graph to compiled --- d3rlpy/algos/qlearning/awac.py | 2 +- d3rlpy/algos/qlearning/bcq.py | 4 ++-- d3rlpy/algos/qlearning/bear.py | 2 +- d3rlpy/algos/qlearning/cal_ql.py | 2 +- d3rlpy/algos/qlearning/cql.py | 4 ++-- d3rlpy/algos/qlearning/crr.py | 2 +- d3rlpy/algos/qlearning/ddpg.py | 2 +- d3rlpy/algos/qlearning/dqn.py | 4 ++-- d3rlpy/algos/qlearning/iql.py | 2 +- d3rlpy/algos/qlearning/nfq.py | 2 +- d3rlpy/algos/qlearning/plas.py | 4 ++-- d3rlpy/algos/qlearning/rebrac.py | 2 +- d3rlpy/algos/qlearning/sac.py | 4 ++-- d3rlpy/algos/qlearning/td3.py | 2 +- d3rlpy/algos/qlearning/td3_plus_bc.py | 2 +- d3rlpy/algos/qlearning/torch/awac_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/bcq_impl.py | 10 +++++----- d3rlpy/algos/qlearning/torch/bear_impl.py | 8 ++++---- d3rlpy/algos/qlearning/torch/cql_impl.py | 8 ++++---- d3rlpy/algos/qlearning/torch/crr_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/ddpg_impl.py | 10 +++++----- d3rlpy/algos/qlearning/torch/dqn_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/iql_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/plas_impl.py | 10 
+++++----- d3rlpy/algos/qlearning/torch/rebrac_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/sac_impl.py | 10 +++++----- d3rlpy/algos/qlearning/torch/td3_impl.py | 4 ++-- d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py | 4 ++-- d3rlpy/algos/transformer/decision_transformer.py | 4 ++-- .../transformer/torch/decision_transformer_impl.py | 8 ++++---- 30 files changed, 68 insertions(+), 68 deletions(-) diff --git a/d3rlpy/algos/qlearning/awac.py b/d3rlpy/algos/qlearning/awac.py index d0662795..701e4803 100644 --- a/d3rlpy/algos/qlearning/awac.py +++ b/d3rlpy/algos/qlearning/awac.py @@ -163,7 +163,7 @@ def inner_create_impl( tau=self._config.tau, lam=self._config.lam, n_action_samples=self._config.n_action_samples, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py index 65a599f2..f89db8d6 100644 --- a/d3rlpy/algos/qlearning/bcq.py +++ b/d3rlpy/algos/qlearning/bcq.py @@ -270,7 +270,7 @@ def inner_create_impl( action_flexibility=self._config.action_flexibility, beta=self._config.beta, rl_start_step=self._config.rl_start_step, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) @@ -432,7 +432,7 @@ def inner_create_impl( gamma=self._config.gamma, action_flexibility=self._config.action_flexibility, beta=self._config.beta, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/bear.py b/d3rlpy/algos/qlearning/bear.py index 5aa00aad..5317abf7 100644 --- a/d3rlpy/algos/qlearning/bear.py +++ b/d3rlpy/algos/qlearning/bear.py @@ -276,7 +276,7 @@ def inner_create_impl( mmd_sigma=self._config.mmd_sigma, vae_kl_weight=self._config.vae_kl_weight, warmup_steps=self._config.warmup_steps, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cal_ql.py b/d3rlpy/algos/qlearning/cal_ql.py index df4782a1..9e3ac64e 100644 --- a/d3rlpy/algos/qlearning/cal_ql.py +++ b/d3rlpy/algos/qlearning/cal_ql.py @@ -179,7 +179,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py index f1107752..de4fb79c 100644 --- a/d3rlpy/algos/qlearning/cql.py +++ b/d3rlpy/algos/qlearning/cql.py @@ -233,7 +233,7 @@ def inner_create_impl( n_action_samples=self._config.n_action_samples, soft_q_backup=self._config.soft_q_backup, max_q_backup=self._config.max_q_backup, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) @@ -348,7 +348,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, alpha=self._config.alpha, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/crr.py b/d3rlpy/algos/qlearning/crr.py index 26c37558..ea046274 100644 --- a/d3rlpy/algos/qlearning/crr.py +++ b/d3rlpy/algos/qlearning/crr.py @@ -204,7 +204,7 @@ def inner_create_impl( tau=self._config.tau, target_update_type=self._config.target_update_type, target_update_interval=self._config.target_update_interval, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/ddpg.py b/d3rlpy/algos/qlearning/ddpg.py index 15daba0b..abb40896 100644 --- 
a/d3rlpy/algos/qlearning/ddpg.py +++ b/d3rlpy/algos/qlearning/ddpg.py @@ -158,7 +158,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py index 67294d5e..4e993b0c 100644 --- a/d3rlpy/algos/qlearning/dqn.py +++ b/d3rlpy/algos/qlearning/dqn.py @@ -109,7 +109,7 @@ def inner_create_impl( target_update_interval=self._config.target_update_interval, modules=modules, gamma=self._config.gamma, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) @@ -220,7 +220,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/iql.py b/d3rlpy/algos/qlearning/iql.py index 598f2ece..56aeb6f3 100644 --- a/d3rlpy/algos/qlearning/iql.py +++ b/d3rlpy/algos/qlearning/iql.py @@ -180,7 +180,7 @@ def inner_create_impl( expectile=self._config.expectile, weight_temp=self._config.weight_temp, max_weight=self._config.max_weight, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py index b6346dfd..634f5c27 100644 --- a/d3rlpy/algos/qlearning/nfq.py +++ b/d3rlpy/algos/qlearning/nfq.py @@ -111,7 +111,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=1, gamma=self._config.gamma, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/plas.py b/d3rlpy/algos/qlearning/plas.py index 0b254e09..6f76f824 100644 --- a/d3rlpy/algos/qlearning/plas.py +++ b/d3rlpy/algos/qlearning/plas.py @@ -203,7 +203,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) @@ -386,7 +386,7 @@ def inner_create_impl( lam=self._config.lam, beta=self._config.beta, warmup_steps=self._config.warmup_steps, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/rebrac.py b/d3rlpy/algos/qlearning/rebrac.py index 9ab75854..b10521d8 100644 --- a/d3rlpy/algos/qlearning/rebrac.py +++ b/d3rlpy/algos/qlearning/rebrac.py @@ -171,7 +171,7 @@ def inner_create_impl( actor_beta=self._config.actor_beta, critic_beta=self._config.critic_beta, update_actor_interval=self._config.update_actor_interval, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/sac.py b/d3rlpy/algos/qlearning/sac.py index 4b12d42d..ef66c40a 100644 --- a/d3rlpy/algos/qlearning/sac.py +++ b/d3rlpy/algos/qlearning/sac.py @@ -195,7 +195,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, gamma=self._config.gamma, tau=self._config.tau, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) @@ -362,7 +362,7 @@ def inner_create_impl( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=self._config.target_update_interval, gamma=self._config.gamma, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3.py b/d3rlpy/algos/qlearning/td3.py index af44fc31..e61bf7b4 100644 --- 
a/d3rlpy/algos/qlearning/td3.py +++ b/d3rlpy/algos/qlearning/td3.py @@ -170,7 +170,7 @@ def inner_create_impl( target_smoothing_sigma=self._config.target_smoothing_sigma, target_smoothing_clip=self._config.target_smoothing_clip, update_actor_interval=self._config.update_actor_interval, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/td3_plus_bc.py b/d3rlpy/algos/qlearning/td3_plus_bc.py index 3ed57cf1..62758a40 100644 --- a/d3rlpy/algos/qlearning/td3_plus_bc.py +++ b/d3rlpy/algos/qlearning/td3_plus_bc.py @@ -163,7 +163,7 @@ def inner_create_impl( target_smoothing_clip=self._config.target_smoothing_clip, alpha=self._config.alpha, update_actor_interval=self._config.update_actor_interval, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index aad6b7b2..f1a46f5b 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -33,7 +33,7 @@ def __init__( tau: float, lam: float, n_action_samples: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -44,7 +44,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._lam = lam diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index bed815d6..896b1ab9 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -71,7 +71,7 @@ def __init__( action_flexibility: float, beta: float, rl_start_step: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -82,7 +82,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._lam = lam @@ -92,7 +92,7 @@ def __init__( self._rl_start_step = rl_start_step self._compute_imitator_grad = ( CudaGraphWrapper(self.compute_imitator_grad) - if compile_graph + if compiled else self.compute_imitator_grad ) @@ -256,7 +256,7 @@ def __init__( gamma: float, action_flexibility: float, beta: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -267,7 +267,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._action_flexibility = action_flexibility diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 7aea6217..7adc6290 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -92,7 +92,7 @@ def __init__( mmd_sigma: float, vae_kl_weight: float, warmup_steps: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -103,7 +103,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._alpha_threshold = alpha_threshold @@ -117,12 +117,12 @@ def __init__( self._warmup_steps = warmup_steps self._compute_warmup_actor_grad = ( CudaGraphWrapper(self.compute_warmup_actor_grad) - if compile_graph + if compiled else self.compute_warmup_actor_grad ) self._compute_imitator_grad = ( CudaGraphWrapper(self.compute_imitator_grad) - if compile_graph + if compiled else 
self.compute_imitator_grad ) diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index e2c6753a..9d986d77 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -60,7 +60,7 @@ def __init__( n_action_samples: int, soft_q_backup: bool, max_q_backup: bool, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -71,7 +71,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._alpha_threshold = alpha_threshold @@ -247,7 +247,7 @@ def __init__( target_update_interval: int, gamma: float, alpha: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -258,7 +258,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, target_update_interval=target_update_interval, gamma=gamma, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 2e2ccd1b..eda3a8b3 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -55,7 +55,7 @@ def __init__( tau: float, target_update_type: str, target_update_interval: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -66,7 +66,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._beta = beta diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 7a427147..620b0eb9 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -73,7 +73,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -88,12 +88,12 @@ def __init__( self._targ_q_func_forwarder = targ_q_func_forwarder self._compute_critic_grad = ( CudaGraphWrapper(self.compute_critic_grad) - if compile_graph + if compiled else self.compute_critic_grad ) self._compute_actor_grad = ( CudaGraphWrapper(self.compute_actor_grad) - if compile_graph + if compiled else self.compute_actor_grad ) hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) @@ -200,7 +200,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -211,7 +211,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) hard_sync(self._modules.targ_policy, self._modules.policy) diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index e7835ff6..0c6f5ba2 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -50,7 +50,7 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, target_update_interval: int, gamma: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -65,7 +65,7 @@ def __init__( self._target_update_interval = target_update_interval self._compute_grad = ( CudaGraphWrapper(self.compute_grad) - if compile_graph + if compiled else self.compute_grad ) hard_sync(modules.targ_q_funcs, 
modules.q_funcs) diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 76c55086..77957d8c 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -51,7 +51,7 @@ def __init__( expectile: float, weight_temp: float, max_weight: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -62,7 +62,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._expectile = expectile diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 1d60010e..0fbe2bfc 100644 --- a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -53,7 +53,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -64,7 +64,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._lam = lam @@ -72,7 +72,7 @@ def __init__( self._warmup_steps = warmup_steps self._compute_imitator_grad = ( CudaGraphWrapper(self.compute_imitator_grad) - if compile_graph + if compiled else self.compute_imitator_grad ) @@ -168,7 +168,7 @@ def __init__( lam: float, beta: float, warmup_steps: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -182,7 +182,7 @@ def __init__( lam=lam, beta=beta, warmup_steps=warmup_steps, - compile_graph=compile_graph, + compiled=compiled, device=device, ) diff --git a/d3rlpy/algos/qlearning/torch/rebrac_impl.py b/d3rlpy/algos/qlearning/torch/rebrac_impl.py index ba4af2be..85340f30 100644 --- a/d3rlpy/algos/qlearning/torch/rebrac_impl.py +++ b/d3rlpy/algos/qlearning/torch/rebrac_impl.py @@ -29,7 +29,7 @@ def __init__( actor_beta: float, critic_beta: float, update_actor_interval: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -43,7 +43,7 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._actor_beta = actor_beta diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index b9521e63..58706147 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -65,7 +65,7 @@ def __init__( targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, gamma: float, tau: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -76,7 +76,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) @@ -160,7 +160,7 @@ def __init__( targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder, target_update_interval: int, gamma: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -175,12 +175,12 @@ def __init__( self._target_update_interval = target_update_interval self._compute_critic_grad = ( CudaGraphWrapper(self.compute_critic_grad) - if compile_graph + if compiled else self.compute_critic_grad ) self._compute_actor_grad = ( CudaGraphWrapper(self.compute_actor_grad) - if compile_graph + if compiled else self.compute_actor_grad ) hard_sync(modules.targ_q_funcs, 
modules.q_funcs) diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 9d9b08fc..2dc6e513 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -27,7 +27,7 @@ def __init__( target_smoothing_sigma: float, target_smoothing_clip: float, update_actor_interval: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -38,7 +38,7 @@ def __init__( targ_q_func_forwarder=targ_q_func_forwarder, gamma=gamma, tau=tau, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._target_smoothing_sigma = target_smoothing_sigma diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 1d844b22..7614e8eb 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -33,7 +33,7 @@ def __init__( target_smoothing_clip: float, alpha: float, update_actor_interval: int, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -47,7 +47,7 @@ def __init__( target_smoothing_sigma=target_smoothing_sigma, target_smoothing_clip=target_smoothing_clip, update_actor_interval=update_actor_interval, - compile_graph=compile_graph, + compiled=compiled, device=device, ) self._alpha = alpha diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index 870a679a..8a822e4e 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -121,7 +121,7 @@ def inner_create_impl( action_size=action_size, modules=modules, device=self._device, - compile_graph=self.compiled, + compiled=self.compiled, ) def get_action_type(self) -> ActionSpace: @@ -235,7 +235,7 @@ def inner_create_impl( warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, - compile_graph=self.compiled, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index 2b03c0b2..a77dfda1 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -42,13 +42,13 @@ def __init__( observation_shape: Shape, action_size: int, modules: Modules, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__(observation_shape, action_size, modules, device) self._compute_grad = ( CudaGraphWrapper(self.compute_grad) - if compile_graph + if compiled else self.compute_grad ) @@ -109,7 +109,7 @@ def __init__( warmup_tokens: int, final_tokens: int, initial_learning_rate: float, - compile_graph: bool, + compiled: bool, device: str, ): super().__init__( @@ -123,7 +123,7 @@ def __init__( self._initial_learning_rate = initial_learning_rate self._compute_grad = ( CudaGraphWrapper(self.compute_grad) - if compile_graph + if compiled else self.compute_grad ) # TODO: Include stateful information in checkpoint. 
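
For reference, the recurring pattern behind the `compiled` changes in this commit looks like the minimal sketch below. It is not part of the patch: `ExampleCriticImpl`, its toy loss, and `update_critic` are hypothetical stand-ins for illustration only. The only behavior it assumes from the library is what the diffs above already show: `CudaGraphWrapper` wraps a one-argument gradient callable and returns a callable with the same signature, the `compiled` flag is `compile_graph and "cuda" in device`, and the optimizer step is invoked outside the wrapped call.

    from typing import Callable

    import torch

    from d3rlpy.torch_utility import CudaGraphWrapper, TorchMiniBatch


    class ExampleCriticImpl:
        """Hypothetical impl showing how `compiled` selects the gradient path."""

        def __init__(
            self,
            q_func: torch.nn.Module,
            optim: torch.optim.Optimizer,
            compiled: bool,
        ):
            self._q_func = q_func
            self._optim = optim
            # compiled is config.compile_graph and "cuda" in device, so CUDA
            # graph capture is only attempted on CUDA devices.
            self._compute_critic_grad: Callable[[TorchMiniBatch], torch.Tensor] = (
                CudaGraphWrapper(self.compute_critic_grad)
                if compiled
                else self.compute_critic_grad
            )

        def compute_critic_grad(self, batch: TorchMiniBatch) -> torch.Tensor:
            # zero_grad and backward live inside the wrapped callable so the
            # whole gradient computation can be captured and replayed.
            self._optim.zero_grad()
            value = self._q_func(batch.observations)
            loss = ((value - batch.rewards) ** 2).mean()  # toy loss for illustration
            loss.backward()
            return loss

        def update_critic(self, batch: TorchMiniBatch) -> float:
            loss = self._compute_critic_grad(batch)
            # The optimizer step stays outside the wrapped call, as in the diffs.
            self._optim.step()
            return float(loss.detach().cpu())

With `compile_graph=False` or a CPU device, the wrapper is skipped entirely and the training step runs exactly as before.
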
From c6e6cd4ccbaf09f222016c3534b65a04e3660933 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 18:45:27 +0900 Subject: [PATCH 14/15] Support BC --- d3rlpy/algos/qlearning/bc.py | 8 ++++++-- d3rlpy/algos/qlearning/torch/bc_impl.py | 24 ++++++++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/d3rlpy/algos/qlearning/bc.py b/d3rlpy/algos/qlearning/bc.py index adb886ee..da9cc275 100644 --- a/d3rlpy/algos/qlearning/bc.py +++ b/d3rlpy/algos/qlearning/bc.py @@ -49,6 +49,7 @@ class BCConfig(LearnableConfig): observation_scaler (d3rlpy.preprocessing.ObservationScaler): Observation preprocessor. action_scaler (d3rlpy.preprocessing.ActionScaler): Action preprocessor. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 100 @@ -95,7 +96,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( imitator.named_modules(), lr=self._config.learning_rate, - compiled=False, + compiled=self.compiled, ) modules = BCModules(optim=optim, imitator=imitator) @@ -105,6 +106,7 @@ def inner_create_impl( action_size=action_size, modules=modules, policy_type=self._config.policy_type, + compiled=self.compiled, device=self._device, ) @@ -139,6 +141,7 @@ class DiscreteBCConfig(LearnableConfig): beta (float): Reguralization factor. observation_scaler (d3rlpy.preprocessing.ObservationScaler): Observation preprocessor. + compile_graph (bool): Flag to enable JIT compilation and CUDAGraph. """ batch_size: int = 100 @@ -172,7 +175,7 @@ def inner_create_impl( optim = self._config.optim_factory.create( imitator.named_modules(), lr=self._config.learning_rate, - compiled=False, + compiled=self.compiled, ) modules = DiscreteBCModules(optim=optim, imitator=imitator) @@ -182,6 +185,7 @@ def inner_create_impl( action_size=action_size, modules=modules, beta=self._config.beta, + compiled=self.compiled, device=self._device, ) diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 8680ef5b..85b1d5f6 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -1,6 +1,6 @@ import dataclasses from abc import ABCMeta, abstractmethod -from typing import Dict, Union +from typing import Callable, Dict, Union import torch from torch.optim import Optimizer @@ -18,7 +18,7 @@ compute_stochastic_imitation_loss, ) from ....optimizers import OptimizerWrapper -from ....torch_utility import Modules, TorchMiniBatch +from ....torch_utility import CudaGraphWrapper, Modules, TorchMiniBatch from ....types import Shape, TorchObservation from ..base import QLearningAlgoImplBase @@ -32,12 +32,14 @@ class BCBaseModules(Modules): class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta): _modules: BCBaseModules + _compute_imitator_grad: Callable[[TorchMiniBatch], ImitationLoss] def __init__( self, observation_shape: Shape, action_size: int, modules: BCBaseModules, + compiled: bool, device: str, ): super().__init__( @@ -46,15 +48,21 @@ def __init__( modules=modules, device=device, ) + self._compute_imitator_grad = ( + CudaGraphWrapper(self.compute_imitator_grad) + if compiled + else self.compute_imitator_grad + ) - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def compute_imitator_grad(self, batch: TorchMiniBatch) -> ImitationLoss: self._modules.optim.zero_grad() - loss = self.compute_loss(batch.observations, batch.actions) - loss.loss.backward() - self._modules.optim.step() + return loss + def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: 
+ loss = self._compute_imitator_grad(batch) + self._modules.optim.step() return asdict_as_float(loss) @abstractmethod @@ -92,12 +100,14 @@ def __init__( action_size: int, modules: BCModules, policy_type: str, + compiled: bool, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, modules=modules, + compiled=compiled, device=device, ) self._policy_type = policy_type @@ -145,12 +155,14 @@ def __init__( action_size: int, modules: DiscreteBCModules, beta: float, + compiled: bool, device: str, ): super().__init__( observation_shape=observation_shape, action_size=action_size, modules=modules, + compiled=compiled, device=device, ) self._beta = beta From 8df3e4689a3722b7c02403b01b369fc152b380f5 Mon Sep 17 00:00:00 2001 From: takuseno Date: Sun, 3 Nov 2024 18:53:53 +0900 Subject: [PATCH 15/15] Add compile option to reproduction scripts --- reproductions/finetuning/awac_finetune.py | 2 ++ reproductions/finetuning/cal_ql_finetune.py | 2 ++ reproductions/finetuning/iql_finetune.py | 2 ++ reproductions/offline/awac.py | 2 ++ reproductions/offline/bcq.py | 2 ++ reproductions/offline/bear.py | 2 ++ reproductions/offline/cql.py | 2 ++ reproductions/offline/crr.py | 2 ++ reproductions/offline/decision_transformer.py | 2 ++ reproductions/offline/discrete_bcq.py | 2 ++ reproductions/offline/discrete_cql.py | 2 ++ reproductions/offline/discrete_decision_transformer.py | 2 ++ reproductions/offline/iql.py | 2 ++ reproductions/offline/nfq.py | 2 ++ reproductions/offline/plas.py | 2 ++ reproductions/offline/plas_with_perturbation.py | 2 ++ reproductions/offline/qr_dqn.py | 2 ++ reproductions/offline/rebrac.py | 2 ++ reproductions/offline/sac.py | 2 ++ reproductions/offline/td3.py | 2 ++ reproductions/offline/td3_plus_bc.py | 2 ++ reproductions/online/double_dqn_online.py | 2 ++ reproductions/online/dqn_online.py | 2 ++ reproductions/online/iqn_online.py | 2 ++ reproductions/online/qr_dqn_online.py | 2 ++ reproductions/online/sac_online.py | 2 ++ 26 files changed, 52 insertions(+) diff --git a/reproductions/finetuning/awac_finetune.py b/reproductions/finetuning/awac_finetune.py index 6cd64375..970cd361 100644 --- a/reproductions/finetuning/awac_finetune.py +++ b/reproductions/finetuning/awac_finetune.py @@ -9,6 +9,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_minari(args.dataset) @@ -30,6 +31,7 @@ def main() -> None: batch_size=1024, lam=1.0, reward_scaler=reward_scaler, + compile_graph=args.compile, ).create(device=args.gpu) awac.fit( diff --git a/reproductions/finetuning/cal_ql_finetune.py b/reproductions/finetuning/cal_ql_finetune.py index 35a492d2..e6498703 100644 --- a/reproductions/finetuning/cal_ql_finetune.py +++ b/reproductions/finetuning/cal_ql_finetune.py @@ -11,6 +11,7 @@ def main() -> None: ) parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # sparse reward setup requires special treatment for failure trajectories @@ -51,6 +52,7 @@ def main() -> None: alpha_threshold=0.8, reward_scaler=reward_scaler, max_q_backup=True, + compile_graph=args.compile, ).create(device=args.gpu) # pretraining diff --git a/reproductions/finetuning/iql_finetune.py b/reproductions/finetuning/iql_finetune.py 
index 942390a4..1dc49da8 100644 --- a/reproductions/finetuning/iql_finetune.py +++ b/reproductions/finetuning/iql_finetune.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="antmaze-umaze-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_minari(args.dataset) @@ -34,6 +35,7 @@ def main() -> None: max_weight=100.0, expectile=0.9, # hyperparameter for antmaze reward_scaler=reward_scaler, + compile_graph=args.compile, ).create(device=args.gpu) # pretraining diff --git a/reproductions/offline/awac.py b/reproductions/offline/awac.py index 3e39d238..345121fe 100644 --- a/reproductions/offline/awac.py +++ b/reproductions/offline/awac.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -27,6 +28,7 @@ def main() -> None: critic_encoder_factory=encoder, batch_size=1024, lam=1.0, + compile_graph=args.compile, ).create(args.gpu) awac.fit( diff --git a/reproductions/offline/bcq.py b/reproductions/offline/bcq.py index 434e6fad..1101d6bb 100644 --- a/reproductions/offline/bcq.py +++ b/reproductions/offline/bcq.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -30,6 +31,7 @@ def main() -> None: lam=0.75, action_flexibility=0.05, n_action_samples=100, + compile_graph=args.compile, ).create(args.gpu) bcq.fit( diff --git a/reproductions/offline/bear.py b/reproductions/offline/bear.py index 184ca194..58ce8644 100644 --- a/reproductions/offline/bear.py +++ b/reproductions/offline/bear.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -39,6 +40,7 @@ def main() -> None: n_target_samples=10, n_action_samples=100, warmup_steps=40000, + compile_graph=args.compile, ).create(device=args.gpu) bear.fit( diff --git a/reproductions/offline/cql.py b/reproductions/offline/cql.py index 3fe8c181..77dc6d6b 100644 --- a/reproductions/offline/cql.py +++ b/reproductions/offline/cql.py @@ -9,6 +9,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -36,6 +37,7 @@ def main() -> None: n_action_samples=10, alpha_threshold=10, conservative_weight=conservative_weight, + compile_graph=args.compile, ).create(device=args.gpu) cql.fit( diff --git a/reproductions/offline/crr.py b/reproductions/offline/crr.py index 4401af80..56977573 100644 --- a/reproductions/offline/crr.py +++ b/reproductions/offline/crr.py 
@@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -23,6 +24,7 @@ def main() -> None: weight_type="binary", advantage_type="mean", target_update_type="soft", + compile_graph=args.compile, ).create(device=args.gpu) crr.fit( diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py index f3849f86..cceea537 100644 --- a/reproductions/offline/decision_transformer.py +++ b/reproductions/offline/decision_transformer.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -46,6 +47,7 @@ def main() -> None: num_heads=1, num_layers=3, max_timestep=1000, + compile_graph=args.compile, ).create(device=args.gpu) dt.fit( diff --git a/reproductions/offline/discrete_bcq.py b/reproductions/offline/discrete_bcq.py index 686a84c0..f74ec222 100644 --- a/reproductions/offline/discrete_bcq.py +++ b/reproductions/offline/discrete_bcq.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--game", type=str, default="breakout") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # fix seed @@ -34,6 +35,7 @@ def main() -> None: reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), action_flexibility=0.3, beta=0.01, + compile_graph=args.compile, ).create(device=args.gpu) env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) diff --git a/reproductions/offline/discrete_cql.py b/reproductions/offline/discrete_cql.py index cfb486b2..e3afa0f8 100644 --- a/reproductions/offline/discrete_cql.py +++ b/reproductions/offline/discrete_cql.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--game", type=str, default="breakout") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() d3rlpy.seed(args.seed) @@ -32,6 +33,7 @@ def main() -> None: observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), target_update_interval=2000, reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), + compile_graph=args.compile, ).create(device=args.gpu) env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) diff --git a/reproductions/offline/discrete_decision_transformer.py b/reproductions/offline/discrete_decision_transformer.py index c0545c94..dc80b966 100644 --- a/reproductions/offline/discrete_decision_transformer.py +++ b/reproductions/offline/discrete_decision_transformer.py @@ -9,6 +9,7 @@ def main() -> None: parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) parser.add_argument("--pre-stack", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() d3rlpy.seed(args.seed) @@ -70,6 +71,7 @@ def main() -> None: observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), max_timestep=max_timestep, 
position_encoding_type=d3rlpy.PositionEncodingType.GLOBAL, + compile_graph=args.compile, ).create(device=args.gpu) n_steps_per_epoch = dataset.transition_count // batch_size diff --git a/reproductions/offline/iql.py b/reproductions/offline/iql.py index 78136d83..abdc26ba 100644 --- a/reproductions/offline/iql.py +++ b/reproductions/offline/iql.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -33,6 +34,7 @@ def main() -> None: max_weight=100.0, expectile=0.7, reward_scaler=reward_scaler, + compile_graph=args.compile, ).create(device=args.gpu) iql.fit( diff --git a/reproductions/offline/nfq.py b/reproductions/offline/nfq.py index eb2fec1e..d9484eff 100644 --- a/reproductions/offline/nfq.py +++ b/reproductions/offline/nfq.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--game", type=str, default="breakout") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # fix seed @@ -28,6 +29,7 @@ def main() -> None: batch_size=32, observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), + compile_graph=args.compile, ).create(device=args.gpu) env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) diff --git a/reproductions/offline/plas.py b/reproductions/offline/plas.py index a1d55ede..106d0fa2 100644 --- a/reproductions/offline/plas.py +++ b/reproductions/offline/plas.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -32,6 +33,7 @@ def main() -> None: batch_size=100, lam=1.0, warmup_steps=500000, + compile_graph=args.compile, ).create(device=args.gpu) plas.fit( diff --git a/reproductions/offline/plas_with_perturbation.py b/reproductions/offline/plas_with_perturbation.py index cc15e9eb..25cf4307 100644 --- a/reproductions/offline/plas_with_perturbation.py +++ b/reproductions/offline/plas_with_perturbation.py @@ -23,6 +23,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -48,6 +49,7 @@ def main() -> None: lam=1.0, warmup_steps=500000, action_flexibility=ACTION_FLEXIBILITY[args.dataset], + compile_graph=args.compile, ).create(device=args.gpu) plas.fit( diff --git a/reproductions/offline/qr_dqn.py b/reproductions/offline/qr_dqn.py index a1765ab1..1b9568e8 100644 --- a/reproductions/offline/qr_dqn.py +++ b/reproductions/offline/qr_dqn.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--game", type=str, default="breakout") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # fix seed @@ -32,6 +33,7 @@ def main() 
-> None: observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), target_update_interval=2000, reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), + compile_graph=args.compile, ).create(device=args.gpu) env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) diff --git a/reproductions/offline/rebrac.py b/reproductions/offline/rebrac.py index de77faab..3098c8dc 100644 --- a/reproductions/offline/rebrac.py +++ b/reproductions/offline/rebrac.py @@ -29,6 +29,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -62,6 +63,7 @@ def main() -> None: actor_beta=actor_beta, critic_beta=critic_beta, observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) rebrac.fit( diff --git a/reproductions/offline/sac.py b/reproductions/offline/sac.py index 8070c520..9649e274 100644 --- a/reproductions/offline/sac.py +++ b/reproductions/offline/sac.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -21,6 +22,7 @@ def main() -> None: critic_learning_rate=3e-4, temp_learning_rate=3e-4, batch_size=256, + compile_graph=args.compile, ).create(device=args.gpu) sac.fit( diff --git a/reproductions/offline/td3.py b/reproductions/offline/td3.py index d91fb190..7aeefe66 100644 --- a/reproductions/offline/td3.py +++ b/reproductions/offline/td3.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -23,6 +24,7 @@ def main() -> None: target_smoothing_sigma=0.2, target_smoothing_clip=0.5, update_actor_interval=2, + compile_graph=args.compile, ).create(device=args.gpu) td3.fit( diff --git a/reproductions/offline/td3_plus_bc.py b/reproductions/offline/td3_plus_bc.py index 60d81f30..953c93d8 100644 --- a/reproductions/offline/td3_plus_bc.py +++ b/reproductions/offline/td3_plus_bc.py @@ -8,6 +8,7 @@ def main() -> None: parser.add_argument("--dataset", type=str, default="hopper-medium-v0") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", type=int) + parser.add_argument("--compile", action="store_true") args = parser.parse_args() dataset, env = d3rlpy.datasets.get_dataset(args.dataset) @@ -25,6 +26,7 @@ def main() -> None: alpha=2.5, update_actor_interval=2, observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) td3.fit( diff --git a/reproductions/online/double_dqn_online.py b/reproductions/online/double_dqn_online.py index 477eece3..31f5215a 100644 --- a/reproductions/online/double_dqn_online.py +++ b/reproductions/online/double_dqn_online.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") parser.add_argument("--seed", type=int, 
default=1) parser.add_argument("--gpu", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # get wrapped atari environment @@ -28,6 +29,7 @@ def main() -> None: optim_factory=d3rlpy.optimizers.RMSpropFactory(), target_update_interval=10000, observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) # replay buffer for experience replay diff --git a/reproductions/online/dqn_online.py b/reproductions/online/dqn_online.py index b28688df..815dffe1 100644 --- a/reproductions/online/dqn_online.py +++ b/reproductions/online/dqn_online.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # get wrapped atari environment @@ -28,6 +29,7 @@ def main() -> None: optim_factory=d3rlpy.optimizers.RMSpropFactory(), target_update_interval=10000 // 4, observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) # replay buffer for experience replay diff --git a/reproductions/online/iqn_online.py b/reproductions/online/iqn_online.py index e00f5074..93099134 100644 --- a/reproductions/online/iqn_online.py +++ b/reproductions/online/iqn_online.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # get wrapped atari environment @@ -29,6 +30,7 @@ def main() -> None: target_update_interval=10000 // 4, q_func_factory=d3rlpy.models.IQNQFunctionFactory(), observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) # replay buffer for experience replay diff --git a/reproductions/online/qr_dqn_online.py b/reproductions/online/qr_dqn_online.py index 17baa45f..bf76ed6e 100644 --- a/reproductions/online/qr_dqn_online.py +++ b/reproductions/online/qr_dqn_online.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() # get wrapped atari environment @@ -31,6 +32,7 @@ def main() -> None: n_quantiles=200 ), observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), + compile_graph=args.compile, ).create(device=args.gpu) # replay buffer for experience replay diff --git a/reproductions/online/sac_online.py b/reproductions/online/sac_online.py index 41ac79ad..0bb99cff 100644 --- a/reproductions/online/sac_online.py +++ b/reproductions/online/sac_online.py @@ -10,6 +10,7 @@ def main() -> None: parser.add_argument("--env", type=str, default="Hopper-v2") parser.add_argument("--seed", type=int, default=1) parser.add_argument("--gpu", action="store_true") + parser.add_argument("--compile", action="store_true") args = parser.parse_args() env = gym.make(args.env) @@ -26,6 +27,7 @@ def main() -> None: actor_learning_rate=3e-4, critic_learning_rate=3e-4, temp_learning_rate=3e-4, + compile_graph=args.compile, ).create(device=args.gpu) # replay buffer for experience replay