From bd7b031c5e88ffd279f902ea3ea8d8748c5fe9ca Mon Sep 17 00:00:00 2001 From: Georg Martius Date: Fri, 11 Jan 2019 09:07:02 +0100 Subject: [PATCH] winner: -1 for opponent closeness to goal proxy removed --- Laser-Hockey-Env.ipynb | 373 +++++++++++++++-------------------------- laser_hockey_env.py | 11 +- 2 files changed, 138 insertions(+), 246 deletions(-) diff --git a/Laser-Hockey-Env.ipynb b/Laser-Hockey-Env.ipynb index 9b13a17..cb53e0a 100644 --- a/Laser-Hockey-Env.ipynb +++ b/Laser-Hockey-Env.ipynb @@ -2,11 +2,11 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T12:53:01.181125Z", - "start_time": "2019-01-08T12:53:01.082265Z" + "end_time": "2019-01-11T07:55:28.444123Z", + "start_time": "2019-01-11T07:55:28.441688Z" } }, "outputs": [], @@ -19,11 +19,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T12:53:01.215186Z", - "start_time": "2019-01-08T12:53:01.212826Z" + "end_time": "2019-01-11T07:55:28.642643Z", + "start_time": "2019-01-11T07:55:28.639000Z" } }, "outputs": [], @@ -45,11 +45,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T12:53:01.841117Z", - "start_time": "2019-01-08T12:53:01.825214Z" + "end_time": "2019-01-11T07:55:29.539930Z", + "start_time": "2019-01-11T07:55:29.525220Z" } }, "outputs": [ @@ -59,7 +59,7 @@ "" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -70,11 +70,11 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:15:40.376470Z", - "start_time": "2019-01-08T08:15:40.369876Z" + "end_time": "2019-01-11T07:55:29.930674Z", + "start_time": "2019-01-11T07:55:29.926764Z" } }, "outputs": [], @@ -86,16 +86,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "have a look at the initialization condition" + "have a look at the initialization condition: alternating who starts and are random in puck position" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:16:21.590059Z", - "start_time": "2019-01-08T08:16:21.577172Z" + "end_time": "2019-01-11T07:57:12.690181Z", + "start_time": "2019-01-11T07:57:12.662104Z" } }, "outputs": [ @@ -105,7 +105,7 @@ "True" ] }, - "execution_count": 17, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -116,13 +116,20 @@ "env.render()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "one episode with random agents" + ] + }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:16:30.112883Z", - "start_time": "2019-01-08T08:16:21.744167Z" + "end_time": "2019-01-11T07:57:22.247098Z", + "start_time": "2019-01-11T07:57:13.886809Z" } }, "outputs": [], @@ -132,8 +139,8 @@ "\n", "for _ in range(600):\n", " env.render()\n", - " a1 = [1,-.5,0] # np.random.uniform(-1,1,3)\n", - " a2 = [1,0.,0] # np.random.uniform(-1,1,3)*0 \n", + " a1 = np.random.uniform(-1,1,3)\n", + " a2 = np.random.uniform(-1,1,3) \n", " obs, r, d, info = env.step(np.hstack([a1,a2])) \n", " obs_agent2 = env.obs_agent_two()\n", " if d: break" @@ -148,22 +155,14 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 26, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:16:38.452297Z", - "start_time": "2019-01-08T08:16:38.409460Z" + "end_time": "2019-01-11T08:00:20.475049Z", + "start_time": "2019-01-11T08:00:20.312847Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Player 1 scored\n" - ] - } - ], + "outputs": [], "source": [ "obs = env.reset()\n", "obs_agent2 = env.obs_agent_two()\n", @@ -180,74 +179,91 @@ "cell_type": "markdown", "metadata": { "ExecuteTime": { - "end_time": "2018-12-20T20:37:41.013424Z", - "start_time": "2018-12-20T20:37:41.009298Z" + "end_time": "2019-01-11T07:57:48.631793Z", + "start_time": "2019-01-11T07:57:48.627528Z" } }, "source": [ - "# Train Shooting" + "\"info\" dict contains useful proxy rewards and winning information" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 27, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:16:58.302953Z", - "start_time": "2019-01-08T08:16:58.296373Z" + "end_time": "2019-01-11T08:00:20.784862Z", + "start_time": "2019-01-11T08:00:20.779373Z" } }, "outputs": [ { "data": { "text/plain": [ - "" + "{'winner': 0,\n", + " 'reward_closeness_to_puck': 0,\n", + " 'reward_touch_puck': 0.0,\n", + " 'reward_puck_direction': 0.0}" ] }, - "execution_count": 20, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "reload(lh)" + "info" ] }, { - "cell_type": "code", - "execution_count": 21, + "cell_type": "markdown", "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:16:58.694873Z", - "start_time": "2019-01-08T08:16:58.688106Z" + "end_time": "2019-01-11T07:59:24.867441Z", + "start_time": "2019-01-11T07:59:24.862324Z" } }, - "outputs": [], "source": [ - "env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_SHOOTING)" + "Winner == 0: draw\n", + "\n", + "Winner == 1: you (left player)\n", + "\n", + "Winner == -1: opponent wins (right player)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-20T20:37:41.013424Z", + "start_time": "2018-12-20T20:37:41.009298Z" + } + }, + "source": [ + "# Train Shooting" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 29, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:40.797119Z", - "start_time": "2019-01-08T08:17:40.791049Z" + "end_time": "2019-01-11T08:00:32.294924Z", + "start_time": "2019-01-11T08:00:32.288528Z" } }, "outputs": [], "source": [ - "o = env.reset()" + "env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_SHOOTING)" ] }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 36, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:40.952762Z", - "start_time": "2019-01-08T08:17:40.935735Z" + "end_time": "2019-01-11T08:01:01.754465Z", + "start_time": "2019-01-11T08:01:01.728781Z" } }, "outputs": [ @@ -257,35 +273,28 @@ "True" ] }, - "execution_count": 56, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "o = env.reset()\n", "env.render()" ] }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 37, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:42.440148Z", - "start_time": "2019-01-08T08:17:41.488133Z" + "end_time": "2019-01-11T08:01:05.573939Z", + "start_time": "2019-01-11T08:01:02.243221Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Player 1 scored\n" - ] - } - ], + "outputs": [], "source": [ - "for _ in range(60):\n", + "for _ in range(200):\n", " env.render()\n", " a1 = [1,0,0] # np.random.uniform(-1,1,3)\n", " a2 = [0,0.,0] \n", @@ -308,11 +317,11 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 38, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:43.974385Z", - "start_time": "2019-01-08T08:17:43.967424Z" + "end_time": "2019-01-11T08:01:07.630627Z", + "start_time": "2019-01-11T08:01:07.625675Z" } }, "outputs": [ @@ -322,7 +331,7 @@ "" ] }, - "execution_count": 58, + "execution_count": 38, "metadata": {}, "output_type": "execute_result" } @@ -333,11 +342,11 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 39, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:44.398292Z", - "start_time": "2019-01-08T08:17:44.391478Z" + "end_time": "2019-01-11T08:01:07.981240Z", + "start_time": "2019-01-11T08:01:07.974283Z" } }, "outputs": [], @@ -347,50 +356,26 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 40, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:55.822383Z", - "start_time": "2019-01-08T08:17:55.815069Z" + "end_time": "2019-01-11T08:01:08.317742Z", + "start_time": "2019-01-11T08:01:08.312949Z" } }, "outputs": [], "source": [ - "o = env.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": 67, - "metadata": { - "ExecuteTime": { - "end_time": "2019-01-08T08:17:55.998906Z", - "start_time": "2019-01-08T08:17:55.974731Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 67, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ + "o = env.reset()\n", "env.render()" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 42, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:17:56.933957Z", - "start_time": "2019-01-08T08:17:56.192972Z" + "end_time": "2019-01-11T08:01:17.885328Z", + "start_time": "2019-01-11T08:01:17.159573Z" } }, "outputs": [ @@ -424,20 +409,6 @@ "# Using discrete actions" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2018-12-21T13:08:25.404585Z", - "start_time": "2018-12-21T13:08:25.392585Z" - } - }, - "outputs": [], - "source": [ - "reload(lh)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -508,11 +479,11 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 43, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:20:41.803577Z", - "start_time": "2019-01-08T08:20:41.790700Z" + "end_time": "2019-01-11T08:01:33.837983Z", + "start_time": "2019-01-11T08:01:33.831404Z" } }, "outputs": [ @@ -522,7 +493,7 @@ "" ] }, - "execution_count": 79, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -533,39 +504,25 @@ }, { "cell_type": "code", - "execution_count": 80, - "metadata": { - "ExecuteTime": { - "end_time": "2019-01-08T08:20:41.980991Z", - "start_time": "2019-01-08T08:20:41.973284Z" - } - }, - "outputs": [], - "source": [ - "env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_DEFENSE)" - ] - }, - { - "cell_type": "code", - "execution_count": 81, + "execution_count": 46, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:20:42.177107Z", - "start_time": "2019-01-08T08:20:42.171320Z" + "end_time": "2019-01-11T08:01:45.035969Z", + "start_time": "2019-01-11T08:01:45.032057Z" } }, "outputs": [], "source": [ - "o = env.reset()" + "env = lh.LaserHockeyEnv()" ] }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 47, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:20:42.429795Z", - "start_time": "2019-01-08T08:20:42.365137Z" + "end_time": "2019-01-11T08:01:45.572163Z", + "start_time": "2019-01-11T08:01:45.504321Z" } }, "outputs": [ @@ -575,22 +532,23 @@ "True" ] }, - "execution_count": 82, + "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "o = env.reset()\n", "env.render()" ] }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 48, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:20:45.648933Z", - "start_time": "2019-01-08T08:20:45.644750Z" + "end_time": "2019-01-11T08:01:49.157281Z", + "start_time": "2019-01-11T08:01:49.152424Z" } }, "outputs": [], @@ -601,36 +559,11 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 49, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:20:46.058583Z", - "start_time": "2019-01-08T08:20:46.052750Z" - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "0.6865685573493514" - ] - }, - "execution_count": 84, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.random.uniform(-1,1)" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "metadata": { - "ExecuteTime": { - "end_time": "2019-01-08T08:20:46.376265Z", - "start_time": "2019-01-08T08:20:46.372358Z" + "end_time": "2019-01-11T08:01:52.268233Z", + "start_time": "2019-01-11T08:01:52.264406Z" } }, "outputs": [], @@ -640,22 +573,14 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 52, "metadata": { "ExecuteTime": { - "end_time": "2019-01-08T08:22:07.314813Z", - "start_time": "2019-01-08T08:22:03.865489Z" + "end_time": "2019-01-11T08:04:40.694729Z", + "start_time": "2019-01-11T08:04:32.333471Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Player 2 scored\n" - ] - } - ], + "outputs": [], "source": [ "obs = env.reset()\n", "obs_agent2 = env.obs_agent_two()\n", @@ -759,65 +684,39 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 62, "metadata": { "ExecuteTime": { - "end_time": "2018-12-27T23:29:58.306819Z", - "start_time": "2018-12-27T23:29:58.282514Z" - }, - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "reload(lh)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "ExecuteTime": { - "end_time": "2018-12-27T23:29:58.703178Z", - "start_time": "2018-12-27T23:29:58.690611Z" + "end_time": "2019-01-11T08:06:22.038375Z", + "start_time": "2019-01-11T08:06:22.035338Z" } }, "outputs": [], "source": [ - "env = lh.LaserHockeyEnv()" + "import time" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 54, "metadata": { "ExecuteTime": { - "end_time": "2018-12-27T23:29:59.125733Z", - "start_time": "2018-12-27T23:29:59.115927Z" + "end_time": "2019-01-11T08:05:10.184886Z", + "start_time": "2019-01-11T08:05:10.180414Z" } }, "outputs": [], "source": [ - "o = env.reset()" + "env = lh.LaserHockeyEnv()" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 55, "metadata": { "ExecuteTime": { - "end_time": "2018-12-27T23:29:59.525158Z", - "start_time": "2018-12-27T23:29:59.441172Z" + "end_time": "2019-01-11T08:05:17.339971Z", + "start_time": "2019-01-11T08:05:17.276199Z" } }, "outputs": [ @@ -827,22 +726,23 @@ "True" ] }, - "execution_count": 13, + "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "o = env.reset()\n", "env.render()" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 56, "metadata": { "ExecuteTime": { - "end_time": "2018-12-27T23:30:04.374930Z", - "start_time": "2018-12-27T23:30:04.341725Z" + "end_time": "2019-01-11T08:05:26.969723Z", + "start_time": "2019-01-11T08:05:26.966375Z" } }, "outputs": [ @@ -867,11 +767,11 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 63, "metadata": { "ExecuteTime": { - "end_time": "2018-12-27T23:30:55.123198Z", - "start_time": "2018-12-27T23:30:52.895958Z" + "end_time": "2019-01-11T08:06:26.444905Z", + "start_time": "2019-01-11T08:06:23.849965Z" }, "scrolled": false }, @@ -886,6 +786,7 @@ ], "source": [ "obs = env.reset()\n", + "time.sleep(1)\n", "obs_agent2 = env.obs_agent_two()\n", "for _ in range(600):\n", " env.render()\n", diff --git a/laser_hockey_env.py b/laser_hockey_env.py index ab3d081..b5f51ba 100644 --- a/laser_hockey_env.py +++ b/laser_hockey_env.py @@ -44,7 +44,7 @@ def BeginContact(self, contact): if self.env.puck == contact.fixtureA.body or self.env.puck == contact.fixtureB.body: print('Player 2 scored') self.env.done = True - self.env.winner = 2 + self.env.winner = -1 if (contact.fixtureA.body == self.env.player1 or contact.fixtureB.body == self.env.player1) \ and (contact.fixtureA.body == self.env.puck or contact.fixtureB.body == self.env.puck): # print("player 1 contacted the puck") @@ -466,14 +466,6 @@ def _compute_reward(self): def _get_info(self): # different proxy rewards: - # how close did the puck get to the goal - reward_closest_to_goal = 0 - if self.done: - max_dist = 10. - max_reward = -1. - factor = max_reward / max_dist - reward_closest_to_goal = self.closest_to_goal_dist*factor # Proxy reward for puck being close to goal - # Proxy reward for being close to puck in the own half reward_closeness_to_puck = 0 if self.puck.position[0] < CENTER_X: @@ -495,7 +487,6 @@ def _get_info(self): return { "winner": self.winner, - "reward_closest_to_goal" : reward_closest_to_goal, "reward_closeness_to_puck" : reward_closeness_to_puck, "reward_touch_puck" : reward_touch_puck, "reward_puck_direction" : reward_puck_direction,