From e176f7918ea7909c17467b9bf3cb764f5af0880a Mon Sep 17 00:00:00 2001 From: Nako Sung Date: Mon, 18 Feb 2019 18:53:32 +0900 Subject: [PATCH 1/2] Fix error in proj_dist --- 6.categorical dqn.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/6.categorical dqn.ipynb b/6.categorical dqn.ipynb index f99f66e..6d21281 100644 --- a/6.categorical dqn.ipynb +++ b/6.categorical dqn.ipynb @@ -198,8 +198,8 @@ " delta_z = float(Vmax - Vmin) / (num_atoms - 1)\n", " support = torch.linspace(Vmin, Vmax, num_atoms)\n", " \n", - " next_dist = target_model(next_state).data.cpu() * support\n", - " next_action = next_dist.sum(2).max(1)[1]\n", + " next_dist = target_model(next_state).data.cpu()\n", + " next_action = (next_dist * support).sum(2).max(1)[1]\n", " next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))\n", " next_dist = next_dist.gather(1, next_action).squeeze(1)\n", " \n", From 9bd074477053ed9da4f6a4fb8c689da680dcffdb Mon Sep 17 00:00:00 2001 From: Nako Sung Date: Mon, 18 Feb 2019 18:55:37 +0900 Subject: [PATCH 2/2] Fix error in proj_dist --- 7.rainbow dqn.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/7.rainbow dqn.ipynb b/7.rainbow dqn.ipynb index 8d2ad10..8c4e88e 100644 --- a/7.rainbow dqn.ipynb +++ b/7.rainbow dqn.ipynb @@ -186,8 +186,8 @@ " delta_z = float(Vmax - Vmin) / (num_atoms - 1)\n", " support = torch.linspace(Vmin, Vmax, num_atoms)\n", " \n", - " next_dist = target_model(next_state).data.cpu() * support\n", - " next_action = next_dist.sum(2).max(1)[1]\n", + " next_dist = target_model(next_state).data.cpu()\n", + " next_action = (next_dist * support).sum(2).max(1)[1]\n", " next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))\n", " next_dist = next_dist.gather(1, next_action).squeeze(1)\n", " \n",