From c7c3bac06bb8b5df22610bab370d30e9caf4e295 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 24 Apr 2024 12:19:48 +0800 Subject: [PATCH] fix(nyz): fix marl nstep td compatibility bug --- ding/policy/madqn.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ding/policy/madqn.py b/ding/policy/madqn.py index 50ceb40b0f..185992af64 100644 --- a/ding/policy/madqn.py +++ b/ding/policy/madqn.py @@ -198,7 +198,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: td_error_per_sample = [] for t in range(self._cfg.collect.unroll_len): v_data = v_nstep_td_data( - total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma + total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None ) # calculate v_nstep_td critic_loss loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep) @@ -231,8 +231,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]: cooperation_loss_all = [] for t in range(self._cfg.collect.unroll_len): v_data = v_nstep_td_data( - cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t], - data['weight'], self._gamma + cooperation_total_q[t], + cooperation_target_total_q[t], + data['reward'][t], + data['done'][t], + data['weight'], + None, ) cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep) cooperation_loss_all.append(cooperation_loss)