Remove modules dependency
takuseno committed Nov 17, 2024
1 parent 9a77f40 commit 9ca1912
Showing 6 changed files with 32 additions and 18 deletions.
8 changes: 5 additions & 3 deletions d3rlpy/algos/qlearning/bcq.py
@@ -427,20 +427,22 @@ def inner_create_impl(
 
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DiscreteBCQLossFn(
-                modules=modules,
                 q_func_forwarder=q_func_forwarder,
                 targ_q_func_forwarder=targ_q_func_forwarder,
+                imitator=imitator,
                 gamma=self._config.gamma,
                 beta=self._config.beta,
             ),
             target_update_interval=self._config.target_update_interval,
             compiled=self.compiled,
         )
         action_sampler = DiscreteBCQActionSampler(
-            modules=modules,
             q_func_forwarder=q_func_forwarder,
+            imitator=imitator,
             action_flexibility=self._config.action_flexibility,
         )
         value_predictor = DQNValuePredictor(q_func_forwarder)
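This change is representative of the whole commit: call sites stop passing the shared Modules container around and instead hand each functional component exactly the objects it uses. A minimal sketch of the pattern with illustrative names (not the actual d3rlpy classes):

import torch.nn as nn
from torch.optim import Optimizer


class UpdaterBefore:
    def __init__(self, modules: "DQNModules"):
        # Implicit dependencies: the reader must inspect the method bodies
        # to learn which parts of `modules` are actually touched.
        self._modules = modules


class UpdaterAfter:
    def __init__(
        self,
        q_funcs: nn.ModuleList,
        targ_q_funcs: nn.ModuleList,
        optim: Optimizer,
    ) -> None:
        # Explicit dependencies: the signature documents what is used, and
        # the class can be built in tests without a full Modules object.
        self._q_funcs = q_funcs
        self._targ_q_funcs = targ_q_funcs
        self._optim = optim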
4 changes: 3 additions & 1 deletion d3rlpy/algos/qlearning/cql.py
@@ -347,7 +347,9 @@ def inner_create_impl(
 
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DiscreteCQLLossFn(
                 action_size=action_size,
                 q_func_forwarder=q_func_forwarder,
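For context on what DiscreteCQLLossFn computes: discrete CQL adds a conservative regularizer to the TD loss, pushing all Q-values down via logsumexp while pushing the dataset action's Q-value back up. A hedged sketch with plain tensors; the function name and shapes here are illustrative, not d3rlpy's API:

import torch


def discrete_cql_regularizer(q_values: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    # q_values: (batch, action_size); actions: (batch,) dataset actions
    logsumexp = torch.logsumexp(q_values, dim=1)
    data_q = q_values.gather(1, actions.view(-1, 1)).squeeze(1)
    return (logsumexp - data_q).mean()


# Usage sketch: total_loss = td_loss + alpha * discrete_cql_regularizer(q, acts),
# where alpha weights the conservatism term.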
8 changes: 6 additions & 2 deletions d3rlpy/algos/qlearning/dqn.py
@@ -111,7 +111,9 @@ def inner_create_impl(
 
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DQNLossFn(
                 q_func_forwarder=forwarder,
                 targ_q_func_forwarder=targ_forwarder,
@@ -239,7 +241,9 @@ def inner_create_impl(
 
         # build functional components
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=DoubleDQNLossFn(
                 q_func_forwarder=forwarder,
                 targ_q_func_forwarder=targ_forwarder,
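The two hunks above wire the same DQNUpdater to DQNLossFn and DoubleDQNLossFn respectively. The difference between those losses lies only in how the bootstrap target is formed, sketched below with plain tensors (d3rlpy wraps this in its forwarder classes; terminal masking is omitted for brevity):

import torch


def dqn_target(targ_q: torch.Tensor, rew: torch.Tensor, gamma: float) -> torch.Tensor:
    # targ_q: target-network Q-values at the next state, (batch, action_size).
    # Standard DQN: the target network both selects and evaluates the action.
    return rew + gamma * targ_q.max(dim=1).values


def double_dqn_target(
    q: torch.Tensor, targ_q: torch.Tensor, rew: torch.Tensor, gamma: float
) -> torch.Tensor:
    # Double DQN: select the action with the online network, evaluate it with
    # the target network, reducing the overestimation bias of the max operator.
    next_actions = q.argmax(dim=1, keepdim=True)
    return rew + gamma * targ_q.gather(1, next_actions).squeeze(1)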
4 changes: 3 additions & 1 deletion d3rlpy/algos/qlearning/nfq.py
@@ -116,7 +116,9 @@ def inner_create_impl(
             gamma=self._config.gamma,
         )
         updater = DQNUpdater(
-            modules=modules,
+            q_funcs=q_funcs,
+            targ_q_funcs=targ_q_funcs,
+            optim=optim,
             dqn_loss_fn=loss_fn,
             target_update_interval=1,
             compiled=self.compiled,
12 changes: 6 additions & 6 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -245,9 +245,9 @@ class DiscreteBCQLoss(DQNLoss):
 class DiscreteBCQLossFn(DoubleDQNLossFn):
     def __init__(
         self,
-        modules: DiscreteBCQModules,
         q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
         targ_q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        imitator: CategoricalPolicy,
         gamma: float,
         beta: float,
     ):
@@ -256,13 +256,13 @@ def __init__(
             targ_q_func_forwarder=targ_q_func_forwarder,
             gamma=gamma,
         )
-        self._modules = modules
+        self._imitator = imitator
         self._beta = beta
 
     def __call__(self, batch: TorchMiniBatch) -> DiscreteBCQLoss:
         td_loss = super().__call__(batch).loss
         imitator_loss = compute_discrete_imitation_loss(
-            policy=self._modules.imitator,
+            policy=self._imitator,
             x=batch.observations,
             action=batch.actions.long(),
             beta=self._beta,
@@ -275,16 +275,16 @@ def __call__(self, batch: TorchMiniBatch) -> DiscreteBCQLoss:
 class DiscreteBCQActionSampler(ActionSampler):
     def __init__(
         self,
-        modules: DiscreteBCQModules,
         q_func_forwarder: DiscreteEnsembleQFunctionForwarder,
+        imitator: CategoricalPolicy,
         action_flexibility: float,
     ):
-        self._modules = modules
         self._q_func_forwarder = q_func_forwarder
+        self._imitator = imitator
         self._action_flexibility = action_flexibility
 
     def __call__(self, x: TorchObservation) -> torch.Tensor:
-        dist = self._modules.imitator(x)
+        dist = self._imitator(x)
         log_probs = F.log_softmax(dist.logits, dim=1)
         ratio = log_probs - log_probs.max(dim=1, keepdim=True).values
         mask = (ratio > math.log(self._action_flexibility)).float()
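The sampler above masks out actions whose imitator log-probability falls more than log(action_flexibility) below the best action's, then acts greedily among the survivors. A self-contained sketch of that rule, using a boolean mask and masked_fill in place of d3rlpy's float-mask arithmetic:

import math

import torch
import torch.nn.functional as F


def bcq_select_action(logits: torch.Tensor, q_values: torch.Tensor, tau: float) -> torch.Tensor:
    # logits: imitator outputs; q_values: Q-estimates; both (batch, action_size).
    log_probs = F.log_softmax(logits, dim=1)
    ratio = log_probs - log_probs.max(dim=1, keepdim=True).values  # always <= 0
    allowed = ratio > math.log(tau)  # tau in (0, 1]; higher tau filters more actions
    # Disallowed actions get -inf so argmax never picks them.
    masked_q = q_values.masked_fill(~allowed, float("-inf"))
    return masked_q.argmax(dim=1)


# Example: action 2 has the highest Q-value in the first row, but the
# imitator considers it unlikely, so action 0 is chosen instead.
logits = torch.tensor([[2.0, 0.0, -2.0], [0.0, 0.0, 0.0]])
q = torch.tensor([[0.1, 0.9, 2.0], [0.3, 0.2, 0.1]])
print(bcq_select_action(logits, q, tau=0.3))  # tensor([0, 0])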
14 changes: 9 additions & 5 deletions d3rlpy/algos/qlearning/torch/dqn_impl.py
@@ -105,25 +105,29 @@ def __call__(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor:
 class DQNUpdater(Updater):
     def __init__(
         self,
-        modules: DQNModules,
+        q_funcs: nn.ModuleList,
+        targ_q_funcs: nn.ModuleList,
+        optim: OptimizerWrapper,
         dqn_loss_fn: DQNLossFn,
         target_update_interval: int,
         compiled: bool,
     ):
-        self._modules = modules
+        self._q_funcs = q_funcs
+        self._targ_q_funcs = targ_q_funcs
+        self._optim = optim
         self._dqn_loss_fn = dqn_loss_fn
         self._target_update_interval = target_update_interval
         self._compute_grad = CudaGraphWrapper(self.compute_grad) if compiled else self.compute_grad
 
     def compute_grad(self, batch: TorchMiniBatch) -> DQNLoss:
-        self._modules.optim.zero_grad()
+        self._optim.zero_grad()
         loss = self._dqn_loss_fn(batch)
         loss.loss.backward()
         return loss
 
     def __call__(self, batch: TorchMiniBatch, grad_step: int) -> dict[str, float]:
         loss = self._compute_grad(batch)
-        self._modules.optim.step()
+        self._optim.step()
         if grad_step % self._target_update_interval == 0:
-            hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs)
+            hard_sync(self._targ_q_funcs, self._q_funcs)
         return asdict_as_float(loss)
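Taken together, DQNUpdater now owns a plain update cycle: zero gradients, backpropagate the loss, step the optimizer, and copy the online parameters into the target network every target_update_interval steps. A minimal runnable sketch with generic torch objects (the loss is a placeholder, not a real TD loss):

import copy

import torch
import torch.nn as nn


def hard_sync(targ: nn.Module, src: nn.Module) -> None:
    # Stand-in for d3rlpy's hard_sync: copy online weights into the target.
    targ.load_state_dict(src.state_dict())


q_func = nn.Linear(4, 2)
targ_q_func = copy.deepcopy(q_func)
optim = torch.optim.Adam(q_func.parameters(), lr=1e-3)
target_update_interval = 100

for grad_step in range(1, 301):
    obs = torch.randn(32, 4)
    loss = q_func(obs).pow(2).mean()  # placeholder objective
    optim.zero_grad()
    loss.backward()
    optim.step()
    if grad_step % target_update_interval == 0:
        hard_sync(targ_q_func, q_func)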
