diff --git a/d3rlpy/algos/qlearning/bcq.py b/d3rlpy/algos/qlearning/bcq.py
index fba40c0a..6224cf82 100644
--- a/d3rlpy/algos/qlearning/bcq.py
+++ b/d3rlpy/algos/qlearning/bcq.py
@@ -427,11 +427,13 @@ def inner_create_impl(
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DiscreteBCQLossFn(
-                modules=modules,
                 q_func_forwarder=q_func_forwarder,
                 targ_q_func_forwarder=targ_q_func_forwarder,
+                imitator=imitator,
                 gamma=self._config.gamma,
                 beta=self._config.beta,
             ),
@@ -439,8 +441,8 @@ def inner_create_impl(
             compiled=self.compiled,
         )
         action_sampler = DiscreteBCQActionSampler(
-            modules=modules,
             q_func_forwarder=q_func_forwarder,
+            imitator=imitator,
             action_flexibility=self._config.action_flexibility,
         )
         value_predictor = DQNValuePredictor(q_func_forwarder)
diff --git a/d3rlpy/algos/qlearning/cql.py b/d3rlpy/algos/qlearning/cql.py
index 53fceacb..170b94f0 100644
--- a/d3rlpy/algos/qlearning/cql.py
+++ b/d3rlpy/algos/qlearning/cql.py
@@ -347,7 +347,9 @@ def inner_create_impl(
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DiscreteCQLLossFn(
                 action_size=action_size,
                 q_func_forwarder=q_func_forwarder,
diff --git a/d3rlpy/algos/qlearning/dqn.py b/d3rlpy/algos/qlearning/dqn.py
index f2bc3a38..d587f6e6 100644
--- a/d3rlpy/algos/qlearning/dqn.py
+++ b/d3rlpy/algos/qlearning/dqn.py
@@ -111,7 +111,9 @@ def inner_create_impl(
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DQNLossFn(
                 q_func_forwarder=forwarder,
                 targ_q_func_forwarder=targ_forwarder,
@@ -239,7 +241,9 @@ def inner_create_impl(
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DoubleDQNLossFn(
                 q_func_forwarder=forwarder,
                 targ_q_func_forwarder=targ_forwarder,
diff --git a/d3rlpy/algos/qlearning/nfq.py b/d3rlpy/algos/qlearning/nfq.py
index 9dc1d637..73d78677 100644
--- a/d3rlpy/algos/qlearning/nfq.py
+++ b/d3rlpy/algos/qlearning/nfq.py
@@ -116,7 +116,9 @@ def inner_create_impl(
             gamma=self._config.gamma,
         )
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=loss_fn,
             target_update_interval=1,
             compiled=self.compiled,
diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py
index 67823b81..dc4bb300 100644
--- a/d3rlpy/algos/qlearning/torch/bcq_impl.py
+++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -245,9 +245,9 @@ class DiscreteBCQLoss(DQNLoss):
 class DiscreteBCQLossFn(DoubleDQNLossFn):
     def __init__(
         self,
-        modules: DiscreteBCQModules,
         q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
         targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        imitator: CategoricalPolicy,
         gamma: float,
         beta: float,
     ):
@@ -256,13 +256,13 @@ def __init__(
             targ_q_func_forwarder=targ_q_func_forwarder,
             gamma=gamma,
         )
-        self._modules = modules
+        self._imitator = imitator
         self._beta = beta

     def __call__(self, batch: TorchMiniBatch) -> DiscreteBCQLoss:
         td_loss = super().__call__(batch).loss
         imitator_loss = compute_discrete_imitation_loss(
-            policy=self._modules.imitator,
+            policy=self._imitator,
             x=batch.observations,
             action=batch.actions.long(),
             beta=self._beta,
@@ -275,16 +275,16 @@ def __call__(self, batch: TorchMiniBatch) -> DiscreteBCQLoss:
 class DiscreteBCQActionSampler(ActionSampler):
     def __init__(
         self,
-        modules: DiscreteBCQModules,
         q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        imitator: CategoricalPolicy,
         action_flexibility: float,
     ):
-        self._modules = modules
         self._q_func_forwarder = q_func_forwarder
+        self._imitator = imitator
         self._action_flexibility = action_flexibility

     def __call__(self, x: TorchObservation) -> torch.Tensor:
-        dist = self._modules.imitator(x)
+        dist = self._imitator(x)
         log_probs = F.log_softmax(dist.logits, dim=1)
         ratio = log_probs - log_probs.max(dim=1, keepdim=True).values
         mask = (ratio > math.log(self._action_flexibility)).float()
diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py
index 8db4a569..fa4f627e 100644
--- a/d3rlpy/algos/qlearning/torch/dqn_impl.py
+++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py
@@ -105,25 +105,29 @@ def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor:
 class DQNUpdater(Updater):
     def __init__(
         self,
-        modules: DQNModules,
+        q_funcs: nn.ModuleList,
+        targ_q_funcs: nn.ModuleList,
+        optim: OptimizerWrapper,
         dqn_loss_fn: DQNLossFn,
         target_update_interval: int,
         compiled: bool,
     ):
-        self._modules = modules
+        self._q_funcs = q_funcs
+        self._targ_q_funcs = targ_q_funcs
+        self._optim = optim
         self._dqn_loss_fn = dqn_loss_fn
         self._target_update_interval = target_update_interval
         self._compute_grad = CudaGraphWrapper(self.compute_grad) if compiled else self.compute_grad

     def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss:
-        self._modules.optim.zero_grad()
+        self._optim.zero_grad()
         loss = self._dqn_loss_fn(batch)
         loss.loss.backward()
         return loss

     def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
         loss = self._compute_grad(batch)
-        self._modules.optim.step()
+        self._optim.step()
         if grad_step % self._target_update_interval == 0:
-            hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)
+            hard_sync(self._targ_q_funcs, self._q_funcs)
         return asdict_as_float(loss)
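
For reference, below is a minimal, self-contained sketch of the pattern this diff moves toward: the updater receives the concrete q-functions, target q-functions, and optimizer directly instead of a `modules` container. It is illustrative only, not d3rlpy's actual classes; `SimpleDQNUpdater`, `hard_sync`, and the toy loss function are hypothetical stand-ins.

import copy
import torch
from torch import nn


def hard_sync(targ: nn.Module, src: nn.Module) -> None:
    # Hypothetical stand-in for d3rlpy's hard_sync: copy parameters into the target network.
    targ.load_state_dict(src.state_dict())


class SimpleDQNUpdater:
    def __init__(
        self,
        q_funcs: nn.Module,
        targ_q_funcs: nn.Module,
        optim: torch.optim.Optimizer,
        loss_fn,  # callable(batch) -> scalar torch.Tensor
        target_update_interval: int,
    ):
        # Explicit components instead of a single modules dataclass.
        self._q_funcs = q_funcs
        self._targ_q_funcs = targ_q_funcs
        self._optim = optim
        self._loss_fn = loss_fn
        self._target_update_interval = target_update_interval

    def __call__(self, batch, grad_step: int) -> dict:
        # Gradient step on the explicitly injected optimizer.
        self._optim.zero_grad()
        loss = self._loss_fn(batch)
        loss.backward()
        self._optim.step()
        # Periodic hard update of the explicitly injected target network.
        if grad_step % self._target_update_interval == 0:
            hard_sync(self._targ_q_funcs, self._q_funcs)
        return {"loss": float(loss.detach())}


# Wiring mirrors the post-diff inner_create_impl: pass q_funcs, targ_q_funcs, and optim directly.
q_funcs = nn.Linear(4, 2)
targ_q_funcs = copy.deepcopy(q_funcs)
optim = torch.optim.Adam(q_funcs.parameters(), lr=1e-3)
updater = SimpleDQNUpdater(
    q_funcs=q_funcs,
    targ_q_funcs=targ_q_funcs,
    optim=optim,
    loss_fn=lambda batch: ((q_funcs(batch["obs"]) - batch["target"]) ** 2).mean(),
    target_update_interval=100,
)
metrics = updater({"obs": torch.randn(8, 4), "target": torch.zeros(8, 2)}, grad_step=0)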