diff --git a/6.categorical dqn.ipynb b/6.categorical dqn.ipynb index f99f66e..6d21281 100644 --- a/6.categorical dqn.ipynb +++ b/6.categorical dqn.ipynb @@ -198,8 +198,8 @@ " delta_z = float(Vmax - Vmin) / (num_atoms - 1)\n", " support = torch.linspace(Vmin, Vmax, num_atoms)\n", " \n", - " next_dist = target_model(next_state).data.cpu() * support\n", - " next_action = next_dist.sum(2).max(1)[1]\n", + " next_dist = target_model(next_state).data.cpu()\n", + " next_action = (next_dist * support).sum(2).max(1)[1]\n", " next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))\n", " next_dist = next_dist.gather(1, next_action).squeeze(1)\n", " \n", diff --git a/7.rainbow dqn.ipynb b/7.rainbow dqn.ipynb index 8d2ad10..8c4e88e 100644 --- a/7.rainbow dqn.ipynb +++ b/7.rainbow dqn.ipynb @@ -186,8 +186,8 @@ " delta_z = float(Vmax - Vmin) / (num_atoms - 1)\n", " support = torch.linspace(Vmin, Vmax, num_atoms)\n", " \n", - " next_dist = target_model(next_state).data.cpu() * support\n", - " next_action = next_dist.sum(2).max(1)[1]\n", + " next_dist = target_model(next_state).data.cpu()\n", + " next_action = (next_dist * support).sum(2).max(1)[1]\n", " next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))\n", " next_dist = next_dist.gather(1, next_action).squeeze(1)\n", " \n",