diff --git a/ding/policy/madqn.py b/ding/policy/madqn.py
index 50ceb40b0f..185992af64 100644
--- a/ding/policy/madqn.py
+++ b/ding/policy/madqn.py
@@ -198,7 +198,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
         td_error_per_sample = []
         for t in range(self._cfg.collect.unroll_len):
             v_data = v_nstep_td_data(
-                total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma
+                total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None
             )
             # calculate v_nstep_td critic_loss
             loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
@@ -231,8 +231,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
         cooperation_loss_all = []
         for t in range(self._cfg.collect.unroll_len):
             v_data = v_nstep_td_data(
-                cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t],
-                data['weight'], self._gamma
+                cooperation_total_q[t],
+                cooperation_target_total_q[t],
+                data['reward'][t],
+                data['done'][t],
+                data['weight'],
+                None,
             )
             cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
             cooperation_loss_all.append(cooperation_loss)
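
For context, both hunks swap the final positional argument of `v_nstep_td_data` from `self._gamma` to `None`. Below is a minimal sketch of the resulting call pattern, not the MADQN policy itself; the import path, the interpretation of the last field as a per-sample discount (often named `value_gamma`), and the `(nstep, batch)` reward layout are assumptions, not taken from this diff.

```python
# Minimal sketch of the call pattern after this change (assumptions noted above).
import torch
from ding.rl_utils import v_nstep_td_data, v_nstep_td_error

batch, nstep, gamma = 4, 3, 0.99
total_q = torch.randn(batch)          # current value estimate per sample
target_total_q = torch.randn(batch)   # bootstrapped value after n steps
reward = torch.randn(nstep, batch)    # the n rewards for the n-step return (assumed shape)
done = torch.zeros(batch)             # termination mask
weight = None                         # per-sample weights; None -> uniform

# With None in the last slot, the TD helper is expected to fall back to its own
# gamma ** nstep discounting of the bootstrap term, rather than treating the
# raw self._gamma float as a precomputed per-sample discount (assumed semantics).
v_data = v_nstep_td_data(total_q, target_total_q, reward, done, weight, None)
loss, td_error_per_sample = v_nstep_td_error(v_data, gamma, nstep)
print(loss.item(), td_error_per_sample.shape)
```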