From 5b22799c3c2020e5cfee61ae8456b78802c21ae7 Mon Sep 17 00:00:00 2001 From: wassname Date: Tue, 13 Oct 2020 11:09:06 +0000 Subject: [PATCH] fixes --- .../Intro_to_NN_Part_2.ipynb | 657 ++++++++++++------ .../Intro_to_NN_Part_2.py | 11 +- 2 files changed, 471 insertions(+), 197 deletions(-) diff --git a/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.ipynb b/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.ipynb index afa01ac..dc3234f 100644 --- a/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.ipynb +++ b/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.ipynb @@ -5,8 +5,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:00.497080Z", - "start_time": "2020-10-13T05:57:00.151314Z" + "end_time": "2020-10-13T11:01:34.137781Z", + "start_time": "2020-10-13T11:01:33.730220Z" } }, "outputs": [], @@ -130,7 +130,7 @@ "Credits to researchers at Georgia Tech, Agile Geoscience\n", "License CCbySA\n", "\n", - "In this notebook, we will be using the landmass dataset, which have been preprocessed already. In this dataset, we have images of 4 different types of landmass: 'Chaotic Horizon', 'Fault', 'Horizon', 'Salt Dome'.\n", + "In this notebook, we will be using the landmass dataset, which have been preprocessed already. In this dataset, we have images of 4 different types of landmass: ['Discontinuous', 'Faulted', 'Continuous', 'Salt'].\n", "\n", "This is an example of [seismic data](https://en.wikipedia.org/wiki/Reflection_seismology) which is a way of using seismic to image the structure of the Earth, below the surface. These waves are similar to sounds waves in air. The lines represent changes in density below the surface.\n", "\n", @@ -142,13 +142,25 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:00.531948Z", - "start_time": "2020-10-13T05:57:00.498549Z" + "end_time": "2020-10-13T11:01:34.157752Z", + "start_time": "2020-10-13T11:01:34.139511Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "device" ] }, { @@ -156,8 +168,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:00.900212Z", - "start_time": "2020-10-13T05:57:00.533714Z" + "end_time": "2020-10-13T11:01:34.545071Z", + "start_time": "2020-10-13T11:01:34.169245Z" } }, "outputs": [ @@ -196,8 +208,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:00.904680Z", - "start_time": "2020-10-13T05:57:00.901784Z" + "end_time": "2020-10-13T11:01:34.550634Z", + "start_time": "2020-10-13T11:01:34.546598Z" } }, "outputs": [ @@ -221,8 +233,8 @@ "execution_count": 5, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:00.953802Z", - "start_time": "2020-10-13T05:57:00.905936Z" + "end_time": "2020-10-13T11:01:34.646797Z", + "start_time": "2020-10-13T11:01:34.552585Z" } }, "outputs": [ @@ -255,8 +267,8 @@ "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:01.067516Z", - "start_time": "2020-10-13T05:57:00.955175Z" + "end_time": "2020-10-13T11:01:34.733743Z", + "start_time": "2020-10-13T11:01:34.648377Z" } }, "outputs": [ @@ -264,14 +276,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Class: Chaotic Horizon\n" + "Class: Discontinuous\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAABjCAAAAACqUOFJAAAfXElEQVR4nC3BWcyl52EQ4Hf73m9fzv6f85/599k9mXFsx+6kTpyQuFnagJpSiYICXIBQQe0NElQIJLhFSFx0AaEWCqWlLW3VJF1om2CnsePYju0Zzz7z79vZz7dv78YNzwP/ZwFFnbdKx2mYESeiyHyJFyTy6mG5XBRaEktucZyCKi30aExis9lGCpvnAa0ND7lsDkvWYBiQdlKqSWUyPhsiISI/FLStxrai5C4CgsNj5bOOf7qSJKDmRZwDHZVXQ1VHdOInkWwiJBoFoUPfMx1XnWdlJIDlYKJmQriKuIYsNQZl3VCKZroOsG/hQY19APsgg1/TIr3hnegoZhj6U2Si8yoKSEMSEyIDIuEU9Tyw2gZMeDEMIM7hUYFDpLsZ1WpcKUtyxq2acKzlxgBRo7bHhtYocCsUelH46Bz+fZUx25vbGGaJu4ibCs+Mpq/ZGgq9nBCtsXC1UelqtPDLmRpk1QKmuo5BAI+kBRFokFNUxqh2Ks9G0BVOVehMM5WgoOKRdBMNEsPemoV1i9omTxlpWxlZhdfj0JUEiUAmsi6MpEjLyvYNuLUb666Ph+XYi2DcVAAJpnHbJVIoaBg6ndbVuEjLzZguMJhb2lE7n6k+WSiNuJrJWDk3E10jzcZgnp5PndzFkwt8oSZ0GUqb+d0Wkm1srDQA7JXHOY3IqlGmsFZyy2pKYXBPZRXC+ZKmbL2APVC5rrPmy0u1Bf9xVVHfJqhYnHtKJ3av5fF7S1IIkyZe7JDUzQFpsF6jyKXTcbFQ9Um3nVRJjHSUayqvtxIH5pAhNZGuRZmWyILoBq2AI5xp2uQluTGNCAPGUMY7c6sWgc+y/nbaH0GlF/PigqMYdRDPIMtom/PRaehGH65f7we2PWKuUULKkDZvq/MJgmkLLVZgosoj6ZE+0dk0OdEqUcLfxhItoKMbEJ9YoDTSpaka+nZYYrE4LzSyUnd1WeM4MrV2+fBoUXaqyijWhw2aEyxC5ULZGFNtEbu4NuXCxLAj0wVOmgHPEgFbKVSE+VbWi9mPsOUHuFtHCk73BpsElSRXLe8gXzS4JkyIpjofL0+U35D9nfTeNKfDIYkrwoo0DvBH1vo1WOep8s60gR6gLM9TyqSTr1AoNXKw9CdOFt9RsLUly+JI8X0c7b7VKbm23vSulDPtyCHF8fm4NLX66mWgjbpN89Pw/ORUwNiysuioEBeOZV1NZ5XUgpEc0aCyBcgRN0kBIF9I+GWL0tP+wnVKVWa6VhGzvqaf3/ViEayZzR0xDx/B1lmatzSrHt5ohaKGi3JreHg8GufYNMqU69zq4/OTEFqK+Hp9zigyYOImpG3xyEB5RDhfrpa+3O7p+cmjezdWS3PTQ0b3DFqCn9x7l3khXCbiZrPRXmYXDj+oigtZmi4ewGYrzOkx0OkVT4w2XDPz1oeLwnDgo8Rf4YsJlaoy1FLgoQ//nc9V5mKsmRa89+SqBfQgrHBzovnROJveyXvdS6dmf82OzDlBj3YZGEiLg/PWJ6jE1VnpWD0jmWtqQABtwgmqTDz2bBrhLC7qwsRzuD6A/8noZNozPT+onJ6vzfOuMSsnVf/yriDRRuM+pH20RN3j0bTqG+SGXPBUqoCNsYXctXmLQHYKuVwk26va+aRbqcz6VHU0sQcrgjF5MLhcgkSQwxnyG/dD/9CCo9v5mXGWpzAJLKLmkJ95L/bq6KGFmKFotmnjV4J6djibtkg3vcedqlKdeLYwmqwRjOcmPt8LKkjk4Pgk7bT1wJdsmjn6ySH8l+mMy6Xatra0hX/S0E+O+3bvyvS9K3wfnLW3vXwS9kptvUqiq1dOx3RN1tM4aLHqQ7TK/aMiDndoT2rszFZNegSW1vKD5jVcPoa4tT7G8YGnEZ24n+zOH6qLemBku7OL+Nh9ac33wz/9IbgA3R6q3ojd1RDXIxev7qeTN8SFS6vGum4z+DncDj37Adm6OkXdeLl2SZ4NL02WmSL31q5euxTzVJNtlxrVhWvwXzgvblald2+cnoDuy7uLzRvj8ePvv/eiY6xuf04cvz8xu5MNsnCE9f5fBJY7W9+5tLmIlOqhUZ6vm+aAHU380EcGuqet2/wpWn121rt5Q7DFOQygpsTAJPaf/MHG5/Kj47GvX1qtP2kUj//448fwyz8/eb/RODo1fjJa5OuWrgfs/fh0ert/fjfJxycZxFQ+QemlVy48eoBGc/O5/MEsCoefsq2bXwC/9s6bnx+i1dw7mJqV8eQQ/uq7H33Up2zngv28g7u3D37v149B57P//Kp39/+O3ri//U9Xg2jj3iIKxO8UqzMJ4tGGo7sa99OMA/9pcmW2a5fLLrLCU6jAxsbf+xt686Pf//0zR7sCi7MqMMp5Cf/tZf7hXae5+lI8TKPh7vf+B7/4/Gc2dt4Ptv/3H34IQPD8F1tqd98g8zz6An3yYD13E3Izm1cGb466j76vAJIAaRW4qg2yiL/8qeCO/ZXy7n+9xxhU3fVWf3I0hL/Wcq/m2uTRVhBr6Tu/q+KvfnW4eva975x82fngvJf9OWhePqKq4c3/bn7WWB/pxtFSGuHs/UaDJvQzW382GjKndZqQGy9srZwn22+/edz+uRfX7uxWb+i9F11dhyWA2Q8OL75Czg/A1n758L9rt79whbUODv4Cnfpff/4o8b77w28CJDvsyqe/FvzBwWeu7K68f3T+sOc+TlvwOf8f3J58D90c8sftw9lVo4G0yd13C8j7N9Yv4YrbYAKiTXIA5+nRaUOu7zx7+vjkIX3t5hefHnh/jfoXzcfXL4MfrnR/8Nsf7Zxb4fovgXb4wfEN3hr/4Pjs0vXNCWy/MBDT9bRqAQAAeJTgXRK+9/JVfPhnhVi95t0wnTvziA8HCfz5Lzfi+8uNmw++NQmaL5BOEGpPZoOfPrGLY9HI3ZdPikczN7PErZF+u/r24XPF+h3RwL2XjdIB8g9/78YL3eKkd11Yi7f788XbDy+/dok+5WfRKLmxItP+Lk8bn4efuR5sHZdnBloG16+fPCjArfLDlwz0F8nFx0+++rpE7kb2g37Rkd+JjJe6/l8GyWXT7S4S0KzEw7NdNt0ZTEIYtG6iR8Ut82T2/bPrL28KTxwueGtJLubjj6sX4H85GXeAuV92ty4U6GH+w2D79Man9v9YpQYwP/va8r0Xsa/Pc4r/z4GNml+qTsxyRwdoeo/LUxld+rHTiYheyPYy6Pfvrq9/ArzxJ1HvIt8erNMTHJ9oO2jx7C58PxofwzjgkeedHPW7b5Xi2s+m93Zf9t/b7pd5vr/Kb752uCj2tRO868DrhfukDcvO9JjZrJt0/dZeAW+ulMff9dfIs/FW0wrvTEUgyS2XS3o6X9m58uy78N/s9B6BxxvwaW2orMOybONnp++nl7vW0+H5yVManGxe2bbS0V9fXDHO1QfEraVJgX3udmxqZk+OhrU+eLx2feOgfLzy4EcHly8NOVh0FkklW4A4OEN97kK6/VLPF1DymsTt4x9c/dyn0a8+XfvE8rwTP+pogkZb+eym2D9ALwb2hO3XmeeumpG61I6OWX9vwjY1942p8+qV0ft3MhW/soF3/DTxew9SrRK9Yb1739iAAICXd14pE0eNl/b++Otfm/3+PXl9ZXzqnhg7G/x0PZnG/cmhfbiC1vYpqPmefnst2m7Fe2fZRofMB9k0uPsotU9aR6DXG27q7az4ePhcvYxDwzZ88HD0Lvzc/Qnp3B7ysbY7WGjf+Prxb9zZ7HlwjH/kXdZ4C6+NbBw9eUDmWbk+uWu18T545Tbtxou4n8632vDOI/p1Hr59Z7J1YysfvVu9JI750mzuhCmfYWPrqvfRH8HHyd4vf8y2G/eE1m+88IvGtx8OG8F0L6rf3OriPXbTaK9mUj56WO2BzY0Pl5q/A4Nttpd1Bo3xvmyu/tl7V1+9cpHnqfkKPH3vd94BAIBWiU0Lmx+D4Or6eulCBcCfvPXOeVjFt4qf+cr8zfKzA5Ydvrk8EZTZ+UOnGu7Ym8+vPn364aL1j35svBu8jN/68PhNtnn97CHmoLUOWfTcznNXXQAmKR7/5jcX7c7Lp9WTYZslJjgvBz/xi/BbADTs8q/+fNVbsW5f/627q6+cNcLDd57ZrxpsZrFHjxPv9fbVbwTnk3cXP3kdAADA27+y2NuXQ02xoCj/9at//Zer2GuturnnDrvgT6fXrtgH98+xEX954/v/4YPhP3wN/sKo/vLal04+Cq6crR3xRwZ5/BbWIpS/8LUt/zDcZm+/uX9NWq++qmvg5MO696nZIX7vN2MvitXtLx3JQvx09/03XHEOCgG/uA7d18H/V0zMLgBv3cMXJvCfyN7tw8vbfQJS8cGBPpj/7qMHKxde/fHu0GiDGNtAfOtbm4Oz8PJLt6pf/43ik8AAz3/w0c4XrGTjOZWXWvHo7FHsg0u0DCvsjTb/tvdkz3MuttSi88OkfwuAhxFh/Qtu/7vnw/6hHh192Pz4nb3+xdb1S68kd5/A3upp48Xmje+vvnBndO+Htzb7fyecmZhfuvVx72+Buvhh4yIZPTPW5bcmNy9d9ck373Yb4Xfp3oGx+KR50Xl479lq9uK9hy78KYvbnz+CYriUxUcnGw8e7Lx+M3Hs1fxolKCN6L3X+wo9Wen64TtHGy++rp0XJ+Pik6sefjaJdhefaE/2blyMf/ntwc6F1kvV1G5WJ6BuTM6OWGDtX7DKw9kROCd0Onk2f+nR8KgdPnmyAfhXX9voTRr8PLz2yeVh2rxylsW3/vLsq1+7tf1W/DFqDkua/a/v/c2LT85Gp6+pt50d54Hb+1p57z6DN77yqf3Tpi10PeidHb7Bp5d20qOHQGPk5jLWPs4evtZ22PKlG3uv/1z75M5gax5Ds57cDy+/8gMwfCAOf/PuT31u60f3v1leAdr8e+rs5eLSDbdx6zsTsv3tyLvyE/13nlTfjw+eHfa3K7Tb6s0nYw0/O0svvbouIPwtyUZvPbKa69pJ8BW5f6kfn54875zn1sl7T+LGyvVq/fLj1v3/lts/495lB1mil/1g6K8HL22wj7aS+4vLZ8dHg08Fp39V3K9QHQGDmFPauNpvcBSPqos9x16Hv1Y5+uTeCx9nh4dfff4pb09HaoO939jc+uab9qvWB7PPu9GlT2vf/t7HYJ01m409xS7evjEC3vxQPQPbs482P/N0Pt9qN59W+/NsHc8XQdU4Ny7evnVm1E/cxWTWpOTspNNqvdyv32PPB7t7/Il1Wsjd+xsoyn/si59N1h+oe3fjbO0nX/6tO5evEdC5nXRpH5Z19ca3KgA6XnYUVpvDxQP3hr25rPqfnO+LNHjC8X5Z6vmJOYnhvID/6kw6vjEnB+pL++NlCSkLoUbO83rzGzfmtLf43fdcXLGv71RP6HP9h6wZ6uF8errMS6nXgDWb74zcWzfqMH7tgl8cz165cDTy4kW+fCTaqoKJ39gcaYRfmJMZXErXL2eh3W8+XNlhvvnxu3DQ/+i9wSe4c+3a+Xz07v6P33zw7tUcjKsHNcuj48Pb31irzs5Mu/Gu3dEbw2ip75vz0+9cYxCUlk5kjnEnqLrY2fDgNwYKgVrOBvLUm/rX2h/afTyR8L6z6u6l8oTvXBDcpWPWc0ZP5LAfqTtBqwjK483X5PbBUTt2cCY5nKSH5NQZtEPkD+O0xlovm9kXqVMuFsAi5dyq+7JacTLzigPtqV0vqoVsXfQXH7ca++egtR0DSd3O+NTYeTqqUmYHTF9pXp++OXZxtXKwdbM5jth8LmZmXWMjz45HdYoDkIv8sXTmpYwQ/AWjSv0KXOSwCUuUPDKVSuzCXa0OZoMyVYHWWGV7ZeKCFF4Uo1mN18Fs1WZoMbM+aHqyHg07VUVWhH+0lmU5pCiVEBOLwUBEgFPTEQWxvMrNaMMrdXLcwYeaT0SvzgfuMyd5mqwM1+ETrcGZ/UGvDXdXr+0Xnc6yofMDTr2BrHF+88Nyb71lmO2gAQx3psXE9UmfojPk1B2VtHXWzYgoPKgreKiZ+7JcTG8FQjQmTTPmPZhZgQ4E/mitEfRcYHYezK7rXK862lFzm4nk/FqW2pdfmJ2118g9FZmjenUNPkG6DX2TDSJEHNSEcsEkGeZ2qPBkQe3UGoerlAk1deVZHuC1wJRWvryI66aMPw1m9NrZgRacJL28qi/LWZHXfWBMN4Kd/dx1J3NTswAofJwn9IDkzpKaDDq1UOeYrMWVDE2XJuKFqFq9fJIUznxnSZr+CDlERwuBbz1L3aO6w0bt7q5wUi0lm4tJyayVsBGEB4nvP0tbPZNWHWtyluheyoNIAhyogQhLV7tQS/grPNRisWnk6UYGRf/RLOtIVA54mExXqFmwxjJI84bag7rnK5zPej25IveewobVHUnA+OhCbjDa0klpyBMqKoJUQ6K2aqcrYBS7pmVU8J8JuO2kyLXyhYFL7Ti1m/5Bu5rUF2rHimp9cD4NivIqmRwPRbZhhqRdGKNsPPBYbYzPWt5krVjVZgVuoyyGqzqbIRQQTY+Q0kqmUggrAn+pbg5a+C7CRpiAwXTe7FnGSft+uL62shCLZpb4JQ9T87P+/rwQ8257rMFMeKsuj+emCK069y8Y9YnBEEtkn4B+lgCjEossb+aF4UwFTNtkhzSSNGhkC6eX5RpZ7UFcGZnbGhSz8cR2673hxgJ1p/duXH24dM9GsQF8XfO6RyX2Q/diGrMoSOa8p6pwQfNllpp1mSsjh1jVBuGG0NoB/I8N8zSh3TSTlM2x0xShEnXZbrNoCuIGsTC2HEi0Q7SpzUDFjjashh7KaAxcO4d+DkWe0ajVUjZ4lmKvhtLJ027XxIzHhqV4KTsZ+ci1ZcVhZbFna/rCjatKlDjPmXJrulGN7JvVse4wVz767iUxXTWCTMRakeZdLo4dljLqOHMO+VnebPVI6JpJjBvNQBYUApgTo6758gj+HHI9xamkzrhlMwAo5lkjEk0z9SpvFmmrmKeJPXIWJ56bM9zM1FLvCWcYp+eUeimjCCkCWa20lp5ErjKBK6u9xHNxnEEbcgHm5CXBYYKVlcPVjBqzHqCSBvaCj8LmEtcGf4BWZyNzImivBIPjjDjSoo7OF0x1YsODrDrf0h2YulnB7JVxgQOe5SRXMatZwWPLLCsH/lE55iFHPkM8pFT4kLJc2Utw3kpn0PWMKjM8vOiP8w6mtlJFB1XcwCe1YVCISqajvVbb1uftyikqYxkVISZam4CYFxbTlXBA3iUH1dg0zLpqhblRm+7IL9JI2pVmrWQtZOpoqYTZ09cuzFOdzFd8qeM6tZgtRVUM8xxXjfWUJYVII2iKcQD0GjUbzGNdRgxhxkvSKS0ywjUKXB5hy7OAkSwkL3ETrdQOHK6oRACOXSRa54EPclidzgJGWOUUqK6EdcaJmWUWYBVzs/lh35uWyLiammgKhK3rHBaFBpRckIGxFBi4Wo1dZBTnXl7rqzYdLNBoSvV57HpU5Zk3edKlM89ZjseZxQEoTB1JZ2IYguaFnoASUP2cKBMp6aioBKH+WJdKZqSh0hoQX+O8jBKvLJcaKKcbuqTtpT62sPUsMmNeDvU9RlXrfOZUYaeJOaeUaXrd1rnRQixz6zzJhax7PUfAllZnZ5o03BFcmKUqpYt5bjnw3xuiYnXU1iJuVo5yCWcolJOd1kKSnLkFbpwyiK8sSq06xX3NBpzEA7wwOEv6MV5SIEMuOfG7flphjPJEc2wZA6Tns0JHUrImIFI5Euoaxl7AVUMUwJuP/LEzT+NPkLFpFjPaFWEeN7lQcTn3TZpiUUCS5hQA4jNmUo/rmZidRhJzZBpwXvO6GXvSnjdpAmxJ5qTFLVoRK4nMoEAZYByWurW6/WjppEllFKQJbM7FkWapYEgn9RGNK3tP9rGidkgzYlrMKA2dpXllASozDlAMUClHEgFX6JYmVYPEUeqQXBe8PtfFeQOXYdVM12jbb03LFhANPxdJwzibYcdeM9ujM9wOnSU2KbLLCmTEBfW0Khpap8AqgTWaya6B4cyBmugale+XXGuRXTa1WmmOiCigMnUquSEU220MddpszyWx9EhYNjYSuw5Dh65AL0Su6acVT3FJxALCDCW1uWL7Sy5IMy9Mo66ZataVLTSQK1lxwuwqNjCzKa5LvVsoPUCLbsXrmQatOK35xDCM0DZtKpYZwsBFpdCaWJSVILVB5YwagbEyJcwG1ObQR0XCsDRKFpECKxA7MMtJwygypedco2VlqESpFdsxCJicB2VzCRXN0Zp+ZpHKTnXq5YDrETGqPBe6bkY2Iw6qbDAwjcpgjq6KFCLJZULiaWlqWmGpkuSA1KEqUlxIQ88sHGpFedofLDqzaJqlM6BjKbPTNqlZoGmU+hKksvb1DGLiOMyQCe4XOatOsQ7rVMtUmVWGWYyQlXOU+Dwnfo4Mm0xizBPuwsIqFZtrWAtNVEWZu5IEsaWipcOmfADKRm7WqdIxPV2qQrpWWQFm5ZqQ0M/IgSACSktio0ySRrLwhGoXka1cxbgBS0KbWGBm6ojaCvmluxZqiXfq2VTYuMYsFtTLF2Y1zfG4NHXpHmRnka23CjNqlcqvKo/KOswSQKmLMWGM1dKRtU9lgyurXhYVnJMh1UVSzZSBlQG7QrTk3AAWyBKet3K3PYXIzHZNM+blojRMRTKWsyzVTD/MxppeygbUM6UpCmWqRAZclDqkFEZCtXmtUQknVkEoQBaw5ULTSgtSb1KK2hgHi7xQmua1/HYensLl1PXgjBe1Mhca0P1QpkFhL0UVifZcE5Xlu2bFMslzaOCClzpcQFhTAnSMlRWQsMSdwlxVlJqIsxImWpsXWGV6bYRaVomsmJjmcqpbPspKaFQ6drFOmTcxZw3GneaJvXA0HRIJNAhLDnLKdMdSHpKWptURHYiAYJUvIDG1SgRLDcRaaptJb97LqGYu4rxUwh5KQ0uhtU7kWdUIsIuWRHfzNtCpkbWrlkktTAtuaIZES5WqBuihoCA1gsQQEAXSIA0zEVpZl1MEQ5dKBDLHMGGnmbUTK82o1DpupnXSzPGpZy9dAJ0Y1gpGDUOrzbSKNa1LdZKnOgYSOsSpCTYpgcuaAEiZ3kRmRaDLFGW4maOcJgoKWJgoDVIDSmZiHhCgwcBsTAmMkLAUz2FRSMOLdWQmap4DpWuaAjnU9bSgwK4NQ4KlteCSCa+WmuNUSpHYRybF2HC0iT0lSFDTKHMVAj9xmyUHuIhyn6tMiyu4aKpZrSNiGzYA5XRpZ47NdVWzKjeolNhkGCpkI1pD3ZG6EQNM5jXXSYawEE4qqOE7DUlq5sbYXgSpKpsmF5lYChVllGhZjeECZA2LioBNSa6WluY2jWkFKGFScG6ZpGCGgRUKIiFqSjMiGR9pQBKxCCKSLySRfhVwi0AtNgPcThfJwZGkVAgjULVvVua4amTUG9IqJLiu+Yq9yrmqYalRs6hAUfmUMVURmpY6zuUCG3mPlpCasCSyLguyKAlcNvEub+m+LAxgQmCUz7jXsEFLIBsTpQJtrjU83ZIcLi9sL6wVoYu0KBtMIg3beV4nNQLeZNTRqsQqLb8Ueo2J5sNYEOVTXR/rAW0UZQQLI8Lcmjj6eKWbQxuJ1Awrt1UJPV9xxpqKJhVkhe4GdirO9BoRz64qCW1pI0JtYbRq6DXLxHMpLFVSEZNW0iNuAxumDtoGmIUMp4ShTOSJU+Nmq5LTMFuUYjzTNS/hfTATZYXNhN4lq2QpueZQt7R0I84R8l2p20Xe8jRhuAxx1kEJEaUUippkCDgGNiYCNL0KMlammmKCDVqQxBADr7DtULmmQAeRmhBHawlOwxnHiui1jCsZIitnM42CHKCMB4EaQ6NQRV2hlJBqLnRQw/9cZAhADfCE2gVSJYy7fErD3kodM4fWmplDlCduApaKWbRfNyuc11mqAzMnMI10xBDSBBCIlQw60PCWMLX8pXJ1TUAigBkqcqYpmVkcy3rJLDMmyHBDh9FIVjDQkEdtEBss1SBcsautJa2AYSJpezQyCic1GEa40pRdZ66fSA8zJmyaCFMoDRo8anvMySGZ6w0AVGUbxD4xzTQou0Zth8TDy41GnnhOotw87ui1ctq+B2eoEoWKC1sy3jQpXo0UMEhhqFbDzKEeCVBSQpCr5xQTB+oQg64gXlxJWBXCoF0qUdcistKlUw10aPO0wKAiHdhaheeGnPEjUuki1nSkiUKrAhYobqiyVevCWalKUmWSO0CvejQFkjnQwFXq5E1OmkrVXMcRo8JcQjeDWeZHxAGZH1a5Hk11rbRapWloZXJe9iwlTNtcNZchp3lZgiigZZ7qwkLLwgzLtm6YCmJUFFAIWqYAqLpGBLcEWZDOMSFzhHRRmnQGcN3LDhxQQ1QzYR91gtz1E+ZUBmaibmLquZq59CJSUNoiOIGKshh6Oncd3aqzRp0xw6grVS9cOytzRCDWiNQMExowbRqIUqZPO8jiNBpoUDVB0hDV1C5dkWkBBAkHUOgh1jXqmahU1Cr0zKGiAiZBJuUcVAVVXA+sJEJ0IDRecUn0QuBBFmLGg66uINJPWnptVOjS1M0EMJlet5algc8sMm0CgqQaefUEaa4XazWpxCgHvIIls4oCMAulnJCxhTHWsG3AlbParjxGgJMKTEMflrUXc4QtoQ8PyqI0m3mtmUmC4xYwK7M0bIWw9B2prxRIL1OWkkphrSRWEWIdkbKuIXFh4XsJw1UMoAnqGFXcb4n/B27CcxMNEC4tAAAAAElFTkSuQmCC\n", "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -285,20 +297,31 @@ "x" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note that this is an unbalanced dataset, so we expect an accuracy of at least 52%, this is out baseline\n", + "labels = pd.Series(landmassf3_train.train_labels).replace(dict(enumerate(landmassf3_train.classes)))\n", + "labels.value_counts() / len(landmassf3_train)" + ] + }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:01.137930Z", - "start_time": "2020-10-13T05:57:01.068859Z" + "end_time": "2020-10-13T11:01:34.801249Z", + "start_time": "2020-10-13T11:01:34.735665Z" } }, "outputs": [ { "data": { "text/plain": [ - "['Chaotic Horizon', 'Fault', 'Horizon', 'Salt Dome']" + "['Discontinuous', 'Faulted', 'Continuous', 'Salt']" ] }, "execution_count": 7, @@ -341,8 +364,8 @@ "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:01.188867Z", - "start_time": "2020-10-13T05:57:01.140038Z" + "end_time": "2020-10-13T11:01:34.874511Z", + "start_time": "2020-10-13T11:01:34.803122Z" } }, "outputs": [ @@ -1046,8 +1069,8 @@ "execution_count": 9, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:01.230427Z", - "start_time": "2020-10-13T05:57:01.190460Z" + "end_time": "2020-10-13T11:01:34.925623Z", + "start_time": "2020-10-13T11:01:34.876534Z" } }, "outputs": [], @@ -1088,8 +1111,8 @@ "execution_count": 10, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:02.831401Z", - "start_time": "2020-10-13T05:57:01.231676Z" + "end_time": "2020-10-13T11:01:36.543151Z", + "start_time": "2020-10-13T11:01:34.926857Z" } }, "outputs": [ @@ -1118,8 +1141,8 @@ "execution_count": 11, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.149187Z", - "start_time": "2020-10-13T05:57:02.832611Z" + "end_time": "2020-10-13T11:01:36.898830Z", + "start_time": "2020-10-13T11:01:36.544561Z" } }, "outputs": [ @@ -1180,8 +1203,8 @@ "execution_count": 12, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.157672Z", - "start_time": "2020-10-13T05:57:03.150671Z" + "end_time": "2020-10-13T11:01:36.947568Z", + "start_time": "2020-10-13T11:01:36.903252Z" } }, "outputs": [ @@ -1189,7 +1212,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 0.0054, 0.0994, -0.0771, -0.0366]], device='cuda:0',\n", + "tensor([[-0.0955, 0.0532, 0.0204, -0.0987]], device='cuda:0',\n", " grad_fn=)\n" ] } @@ -1221,15 +1244,15 @@ "execution_count": 13, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.216252Z", - "start_time": "2020-10-13T05:57:03.158919Z" + "end_time": "2020-10-13T11:01:36.963650Z", + "start_time": "2020-10-13T11:01:36.948781Z" } }, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 13, @@ -1261,8 +1284,8 @@ "execution_count": 14, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.422893Z", - "start_time": "2020-10-13T05:57:03.217604Z" + "end_time": "2020-10-13T11:01:37.236500Z", + "start_time": "2020-10-13T11:01:36.965111Z" } }, "outputs": [], @@ -1288,8 +1311,8 @@ "execution_count": 15, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.437448Z", - "start_time": "2020-10-13T05:57:03.424363Z" + "end_time": "2020-10-13T11:01:37.247794Z", + "start_time": "2020-10-13T11:01:37.238271Z" } }, "outputs": [], @@ -1331,8 +1354,8 @@ "execution_count": 16, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:03.511459Z", - "start_time": "2020-10-13T05:57:03.438766Z" + "end_time": "2020-10-13T11:01:37.300329Z", + "start_time": "2020-10-13T11:01:37.249243Z" } }, "outputs": [], @@ -1359,8 +1382,8 @@ "execution_count": 17, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:13.177176Z", - "start_time": "2020-10-13T05:57:03.512890Z" + "end_time": "2020-10-13T11:01:49.416804Z", + "start_time": "2020-10-13T11:01:37.301715Z" }, "scrolled": true }, @@ -1368,7 +1391,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "570d35ae09d4432b894cb59a847ad2b3", + "model_id": "9e5a344cc65441498d0141c7bc5e3e62", "version_major": 2, "version_minor": 0 }, @@ -1383,26 +1406,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "[1, 10] loss: 0.0071\n", - "[1, 20] loss: 0.00706\n", - "[1, 30] loss: 0.00698\n", - "[1, 40] loss: 0.00694\n", - "[1, 50] loss: 0.00688\n", - "[1, 60] loss: 0.00685\n", - "[1, 70] loss: 0.0068\n", - "[1, 80] loss: 0.00674\n", - "[1, 90] loss: 0.00672\n", - "[1, 100] loss: 0.00669\n", - "[1, 110] loss: 0.00663\n", - "[1, 120] loss: 0.00656\n", - "[1, 130] loss: 0.00653\n", - "[1, 140] loss: 0.00645\n", - "[1, 150] loss: 0.00642\n", - "[1, 160] loss: 0.00635\n", - "[1, 170] loss: 0.00627\n", - "[1, 180] loss: 0.00619\n", - "[1, 190] loss: 0.0061\n", - "[1, 200] loss: 0.00606\n", + "[1, 10] loss: 0.00687\n", + "[1, 20] loss: 0.00681\n", + "[1, 30] loss: 0.00681\n", + "[1, 40] loss: 0.00675\n", + "[1, 50] loss: 0.00674\n", + "[1, 60] loss: 0.00669\n", + "[1, 70] loss: 0.00663\n", + "[1, 80] loss: 0.00661\n", + "[1, 90] loss: 0.00663\n", + "[1, 100] loss: 0.00655\n", + "[1, 110] loss: 0.00651\n", + "[1, 120] loss: 0.00642\n", + "[1, 130] loss: 0.00643\n", + "[1, 140] loss: 0.00636\n", + "[1, 150] loss: 0.00628\n", + "[1, 160] loss: 0.00625\n", + "[1, 170] loss: 0.00614\n", + "[1, 180] loss: 0.00603\n", + "[1, 190] loss: 0.00602\n", + "[1, 200] loss: 0.00599\n", "\n", "Finished Training\n", "2371 4417\n", @@ -1430,8 +1453,8 @@ "execution_count": 18, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:46.593680Z", - "start_time": "2020-10-13T05:57:13.178565Z" + "end_time": "2020-10-13T11:02:31.312762Z", + "start_time": "2020-10-13T11:01:49.418135Z" }, "scrolled": true }, @@ -1439,7 +1462,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "18161e9d03da4bc198a9bd90c9407a5b", + "model_id": "29c13920b3a7479988c9c36c46d78360", "version_major": 2, "version_minor": 0 }, @@ -1454,106 +1477,106 @@ "name": "stdout", "output_type": "stream", "text": [ - "[1, 10] loss: 0.00593\n", - "[1, 20] loss: 0.00591\n", + "[1, 10] loss: 0.00588\n", + "[1, 20] loss: 0.00584\n", "[1, 30] loss: 0.00607\n", - "[1, 40] loss: 0.00602\n", - "[1, 50] loss: 0.00605\n", - "[1, 60] loss: 0.00605\n", - "[1, 70] loss: 0.00591\n", - "[1, 80] loss: 0.00597\n", - "[1, 90] loss: 0.0061\n", - "[1, 100] loss: 0.00606\n", - "[1, 110] loss: 0.00604\n", - "[1, 120] loss: 0.00595\n", + "[1, 40] loss: 0.00597\n", + "[1, 50] loss: 0.00604\n", + "[1, 60] loss: 0.00601\n", + "[1, 70] loss: 0.00586\n", + "[1, 80] loss: 0.00596\n", + "[1, 90] loss: 0.00612\n", + "[1, 100] loss: 0.00603\n", + "[1, 110] loss: 0.00602\n", + "[1, 120] loss: 0.00591\n", "[1, 130] loss: 0.00604\n", - "[1, 140] loss: 0.00598\n", - "[1, 150] loss: 0.006\n", - "[1, 160] loss: 0.00599\n", - "[1, 170] loss: 0.00596\n", - "[1, 180] loss: 0.00594\n", - "[1, 190] loss: 0.00594\n", - "[1, 200] loss: 0.00598\n", - "[2, 10] loss: 0.00593\n", - "[2, 20] loss: 0.00591\n", + "[1, 140] loss: 0.00597\n", + "[1, 150] loss: 0.00594\n", + "[1, 160] loss: 0.00598\n", + "[1, 170] loss: 0.0059\n", + "[1, 180] loss: 0.00584\n", + "[1, 190] loss: 0.0059\n", + "[1, 200] loss: 0.00593\n", + "[2, 10] loss: 0.00588\n", + "[2, 20] loss: 0.00584\n", "[2, 30] loss: 0.00607\n", - "[2, 40] loss: 0.00602\n", - "[2, 50] loss: 0.00605\n", - "[2, 60] loss: 0.00605\n", - "[2, 70] loss: 0.00591\n", - "[2, 80] loss: 0.00597\n", - "[2, 90] loss: 0.0061\n", - "[2, 100] loss: 0.00606\n", - "[2, 110] loss: 0.00604\n", - "[2, 120] loss: 0.00595\n", + "[2, 40] loss: 0.00597\n", + "[2, 50] loss: 0.00604\n", + "[2, 60] loss: 0.00601\n", + "[2, 70] loss: 0.00586\n", + "[2, 80] loss: 0.00596\n", + "[2, 90] loss: 0.00612\n", + "[2, 100] loss: 0.00603\n", + "[2, 110] loss: 0.00602\n", + "[2, 120] loss: 0.00591\n", "[2, 130] loss: 0.00604\n", - "[2, 140] loss: 0.00598\n", - "[2, 150] loss: 0.006\n", - "[2, 160] loss: 0.00599\n", - "[2, 170] loss: 0.00596\n", - "[2, 180] loss: 0.00594\n", - "[2, 190] loss: 0.00594\n", - "[2, 200] loss: 0.00598\n", - "[3, 10] loss: 0.00593\n", - "[3, 20] loss: 0.00591\n", + "[2, 140] loss: 0.00597\n", + "[2, 150] loss: 0.00594\n", + "[2, 160] loss: 0.00598\n", + "[2, 170] loss: 0.0059\n", + "[2, 180] loss: 0.00584\n", + "[2, 190] loss: 0.0059\n", + "[2, 200] loss: 0.00593\n", + "[3, 10] loss: 0.00588\n", + "[3, 20] loss: 0.00584\n", "[3, 30] loss: 0.00607\n", - "[3, 40] loss: 0.00602\n", - "[3, 50] loss: 0.00605\n", - "[3, 60] loss: 0.00605\n", - "[3, 70] loss: 0.00591\n", - "[3, 80] loss: 0.00597\n", - "[3, 90] loss: 0.0061\n", - "[3, 100] loss: 0.00606\n", - "[3, 110] loss: 0.00604\n", - "[3, 120] loss: 0.00595\n", + "[3, 40] loss: 0.00597\n", + "[3, 50] loss: 0.00604\n", + "[3, 60] loss: 0.00601\n", + "[3, 70] loss: 0.00586\n", + "[3, 80] loss: 0.00596\n", + "[3, 90] loss: 0.00612\n", + "[3, 100] loss: 0.00603\n", + "[3, 110] loss: 0.00602\n", + "[3, 120] loss: 0.00591\n", "[3, 130] loss: 0.00604\n", - "[3, 140] loss: 0.00598\n", - "[3, 150] loss: 0.006\n", - "[3, 160] loss: 0.00599\n", - "[3, 170] loss: 0.00596\n", - "[3, 180] loss: 0.00594\n", - "[3, 190] loss: 0.00594\n", - "[3, 200] loss: 0.00598\n", - "[4, 10] loss: 0.00593\n", - "[4, 20] loss: 0.00591\n", + "[3, 140] loss: 0.00597\n", + "[3, 150] loss: 0.00594\n", + "[3, 160] loss: 0.00598\n", + "[3, 170] loss: 0.0059\n", + "[3, 180] loss: 0.00584\n", + "[3, 190] loss: 0.0059\n", + "[3, 200] loss: 0.00593\n", + "[4, 10] loss: 0.00588\n", + "[4, 20] loss: 0.00584\n", "[4, 30] loss: 0.00607\n", - "[4, 40] loss: 0.00602\n", - "[4, 50] loss: 0.00605\n", - "[4, 60] loss: 0.00605\n", - "[4, 70] loss: 0.00591\n", - "[4, 80] loss: 0.00597\n", - "[4, 90] loss: 0.0061\n", - "[4, 100] loss: 0.00606\n", - "[4, 110] loss: 0.00604\n", - "[4, 120] loss: 0.00595\n", + "[4, 40] loss: 0.00597\n", + "[4, 50] loss: 0.00604\n", + "[4, 60] loss: 0.00601\n", + "[4, 70] loss: 0.00586\n", + "[4, 80] loss: 0.00596\n", + "[4, 90] loss: 0.00612\n", + "[4, 100] loss: 0.00603\n", + "[4, 110] loss: 0.00602\n", + "[4, 120] loss: 0.00591\n", "[4, 130] loss: 0.00604\n", - "[4, 140] loss: 0.00598\n", - "[4, 150] loss: 0.006\n", - "[4, 160] loss: 0.00599\n", - "[4, 170] loss: 0.00596\n", - "[4, 180] loss: 0.00594\n", - "[4, 190] loss: 0.00594\n", - "[4, 200] loss: 0.00598\n", - "[5, 10] loss: 0.00593\n", - "[5, 20] loss: 0.00591\n", + "[4, 140] loss: 0.00597\n", + "[4, 150] loss: 0.00594\n", + "[4, 160] loss: 0.00598\n", + "[4, 170] loss: 0.0059\n", + "[4, 180] loss: 0.00584\n", + "[4, 190] loss: 0.0059\n", + "[4, 200] loss: 0.00593\n", + "[5, 10] loss: 0.00588\n", + "[5, 20] loss: 0.00584\n", "[5, 30] loss: 0.00607\n", - "[5, 40] loss: 0.00602\n", - "[5, 50] loss: 0.00605\n", - "[5, 60] loss: 0.00605\n", - "[5, 70] loss: 0.00591\n", - "[5, 80] loss: 0.00597\n", - "[5, 90] loss: 0.0061\n", - "[5, 100] loss: 0.00606\n", - "[5, 110] loss: 0.00604\n", - "[5, 120] loss: 0.00595\n", + "[5, 40] loss: 0.00597\n", + "[5, 50] loss: 0.00604\n", + "[5, 60] loss: 0.00601\n", + "[5, 70] loss: 0.00586\n", + "[5, 80] loss: 0.00596\n", + "[5, 90] loss: 0.00612\n", + "[5, 100] loss: 0.00603\n", + "[5, 110] loss: 0.00602\n", + "[5, 120] loss: 0.00591\n", "[5, 130] loss: 0.00604\n", - "[5, 140] loss: 0.00598\n", - "[5, 150] loss: 0.006\n", - "[5, 160] loss: 0.00599\n", - "[5, 170] loss: 0.00596\n", - "[5, 180] loss: 0.00594\n", - "[5, 190] loss: 0.00594\n", - "[5, 200] loss: 0.00598\n", + "[5, 140] loss: 0.00597\n", + "[5, 150] loss: 0.00594\n", + "[5, 160] loss: 0.00598\n", + "[5, 170] loss: 0.0059\n", + "[5, 180] loss: 0.00584\n", + "[5, 190] loss: 0.0059\n", + "[5, 200] loss: 0.00593\n", "\n", "Finished Training\n", "Testing accuracy on unseen data...\n", @@ -1587,8 +1610,8 @@ "execution_count": 19, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:57:46.604546Z", - "start_time": "2020-10-13T05:57:46.595221Z" + "end_time": "2020-10-13T11:02:31.323080Z", + "start_time": "2020-10-13T11:02:31.314288Z" } }, "outputs": [], @@ -1647,11 +1670,11 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:58:30.756771Z", - "start_time": "2020-10-13T05:58:20.999118Z" + "end_time": "2020-10-13T11:03:06.466022Z", + "start_time": "2020-10-13T11:02:31.325667Z" }, "scrolled": true }, @@ -1659,7 +1682,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "87d07c458f3b4b41a4efe0eeaa7736db", + "model_id": "5c393b5324ce4659ade19c41a87d5ce8", "version_major": 2, "version_minor": 0 }, @@ -1674,44 +1697,56 @@ "name": "stdout", "output_type": "stream", "text": [ - "[1, 10] loss: 0.0575\n", - "[1, 20] loss: 0.00931\n", - "[1, 30] loss: 0.00228\n", - "[1, 40] loss: 0.00118\n", - "[1, 50] loss: 0.000458\n", - "[1, 60] loss: 0.000296\n", - "[1, 70] loss: 0.000515\n", - "[1, 80] loss: 0.000351\n", - "\n" + "[1, 10] loss: 0.0611\n", + "[1, 20] loss: 0.00932\n", + "[1, 30] loss: 0.00369\n", + "[1, 40] loss: 0.000794\n", + "[1, 50] loss: 0.000633\n", + "[1, 60] loss: 0.000391\n", + "[1, 70] loss: 0.000446\n", + "[1, 80] loss: 0.000202\n", + "[1, 90] loss: 0.000298\n", + "[1, 100] loss: 0.000156\n", + "[1, 110] loss: 0.000175\n", + "[1, 120] loss: 8.21e-05\n", + "[1, 130] loss: 0.000153\n", + "[1, 140] loss: 5.85e-05\n", + "[1, 150] loss: 6.75e-05\n", + "[1, 160] loss: 3.3e-05\n", + "[1, 170] loss: 0.000117\n", + "[1, 180] loss: 6.06e-05\n", + "[1, 190] loss: 0.000263\n", + "[1, 200] loss: 0.000141\n", + "\n", + "Finished Training\n", + "2613 4417\n" ] }, { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mconvnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBetterCNN\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdam\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvnet\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvnet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_train\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_test\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, x, y, criterion, optimizer, n_epochs, bs)\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;31m# print statistics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m \u001b[0mrunning_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 25\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m9\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"[%d, %5d] loss: %.3g\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_loss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m2000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] + "data": { + "text/plain": [ + "59.157799411365176" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "convnet = BetterCNN().to(device)\n", "optimizer = torch.optim.Adam(convnet.parameters(), lr=learning_rate)\n", - "model = train(convnet, x_train, y_train, criterion, optimizer)\n", - "test(model, x_test, y_test)" + "convnet = train(convnet, x_train, y_train, criterion, optimizer)\n", + "test(convnet, x_test, y_test)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T05:58:31.862332Z", - "start_time": "2020-10-13T05:58:31.829274Z" + "end_time": "2020-10-13T11:03:06.510194Z", + "start_time": "2020-10-13T11:03:06.468783Z" } }, "outputs": [ @@ -1772,7 +1807,7 @@ "1" ] }, - "execution_count": 22, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -1821,14 +1856,248 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T06:00:52.239227Z", - "start_time": "2020-10-13T05:59:23.760111Z" + "end_time": "2020-10-13T11:08:21.574558Z", + "start_time": "2020-10-13T11:03:52.622772Z" } }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dbc78f3990644a4b94c94c6618ca47f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 10] loss: 0.0546\n", + "[1, 20] loss: 0.00867\n", + "[1, 30] loss: 0.00245\n", + "[1, 40] loss: 0.00125\n", + "[1, 50] loss: 0.00066\n", + "[1, 60] loss: 0.000289\n", + "[1, 70] loss: 0.000436\n", + "[1, 80] loss: 0.000251\n", + "[1, 90] loss: 0.000247\n", + "[1, 100] loss: 0.000156\n", + "[1, 110] loss: 0.000177\n", + "[1, 120] loss: 8.18e-05\n", + "[1, 130] loss: 0.000297\n", + "[1, 140] loss: 0.000167\n", + "[1, 150] loss: 7.41e-05\n", + "[1, 160] loss: 8.3e-05\n", + "[1, 170] loss: 0.000166\n", + "[1, 180] loss: 7.77e-05\n", + "[1, 190] loss: 0.000213\n", + "[1, 200] loss: 0.000132\n", + "[2, 10] loss: 0.011\n", + "[2, 20] loss: 0.00274\n", + "[2, 30] loss: 0.00131\n", + "[2, 40] loss: 0.000886\n", + "[2, 50] loss: 0.000951\n", + "[2, 60] loss: 0.00095\n", + "[2, 70] loss: 0.000232\n", + "[2, 80] loss: 0.000986\n", + "[2, 90] loss: 0.000333\n", + "[2, 100] loss: 0.000176\n", + "[2, 110] loss: 0.000287\n", + "[2, 120] loss: 4.85e-05\n", + "[2, 130] loss: 0.000151\n", + "[2, 140] loss: 0.000133\n", + "[2, 150] loss: 8.96e-05\n", + "[2, 160] loss: 5.17e-05\n", + "[2, 170] loss: 0.000309\n", + "[2, 180] loss: 0.000299\n", + "[2, 190] loss: 0.000136\n", + "[2, 200] loss: 0.000249\n", + "[3, 10] loss: 0.0153\n", + "[3, 20] loss: 0.00479\n", + "[3, 30] loss: 0.0032\n", + "[3, 40] loss: 0.000814\n", + "[3, 50] loss: 0.000905\n", + "[3, 60] loss: 0.000566\n", + "[3, 70] loss: 0.00085\n", + "[3, 80] loss: 0.000827\n", + "[3, 90] loss: 0.000585\n", + "[3, 100] loss: 0.000284\n", + "[3, 110] loss: 0.000662\n", + "[3, 120] loss: 4.97e-05\n", + "[3, 130] loss: 0.000122\n", + "[3, 140] loss: 8.32e-06\n", + "[3, 150] loss: 6.13e-05\n", + "[3, 160] loss: 1.46e-05\n", + "[3, 170] loss: 0.000213\n", + "[3, 180] loss: 2.82e-05\n", + "[3, 190] loss: 0.000224\n", + "[3, 200] loss: 0.000181\n", + "[4, 10] loss: 0.0216\n", + "[4, 20] loss: 0.00806\n", + "[4, 30] loss: 0.00127\n", + "[4, 40] loss: 0.00103\n", + "[4, 50] loss: 0.000891\n", + "[4, 60] loss: 0.000826\n", + "[4, 70] loss: 0.000115\n", + "[4, 80] loss: 0.00058\n", + "[4, 90] loss: 8.08e-05\n", + "[4, 100] loss: 0.000193\n", + "[4, 110] loss: 0.000151\n", + "[4, 120] loss: 4.11e-08\n", + "[4, 130] loss: 0.000209\n", + "[4, 140] loss: 0.000189\n", + "[4, 150] loss: 2e-05\n", + "[4, 160] loss: 0.000213\n", + "[4, 170] loss: 0.000261\n", + "[4, 180] loss: 0.0002\n", + "[4, 190] loss: 0.000319\n", + "[4, 200] loss: 0.000222\n", + "[5, 10] loss: 0.00438\n", + "[5, 20] loss: 0.00259\n", + "[5, 30] loss: 0.00053\n", + "[5, 40] loss: 0.000287\n", + "[5, 50] loss: 0.000234\n", + "[5, 60] loss: 8.15e-05\n", + "[5, 70] loss: 0.000467\n", + "[5, 80] loss: 0.000262\n", + "[5, 90] loss: 6.61e-05\n", + "[5, 100] loss: 0.000319\n", + "[5, 110] loss: 0.000169\n", + "[5, 120] loss: 1.92e-05\n", + "[5, 130] loss: 0.000249\n", + "[5, 140] loss: 8.32e-06\n", + "[5, 150] loss: 0.000249\n", + "[5, 160] loss: 5.6e-05\n", + "[5, 170] loss: 0.000238\n", + "[5, 180] loss: 0.000165\n", + "[5, 190] loss: 0.000272\n", + "[5, 200] loss: 0.000134\n", + "[6, 10] loss: 0.000165\n", + "[6, 20] loss: 0.000151\n", + "[6, 30] loss: 0.000108\n", + "[6, 40] loss: 9.68e-05\n", + "[6, 50] loss: 7.55e-05\n", + "[6, 60] loss: 0.000193\n", + "[6, 70] loss: 0.000394\n", + "[6, 80] loss: 0.000145\n", + "[6, 90] loss: 3.32e-07\n", + "[6, 100] loss: 9.72e-05\n", + "[6, 110] loss: 0.00016\n", + "[6, 120] loss: 6.4e-08\n", + "[6, 130] loss: 0.000136\n", + "[6, 140] loss: 0.0001\n", + "[6, 150] loss: 0.000105\n", + "[6, 160] loss: 2.78e-07\n", + "[6, 170] loss: 0.000121\n", + "[6, 180] loss: 0.000133\n", + "[6, 190] loss: 0.000363\n", + "[6, 200] loss: 0.00019\n", + "[7, 10] loss: 0.0112\n", + "[7, 20] loss: 0.00327\n", + "[7, 30] loss: 0.000902\n", + "[7, 40] loss: 0.00182\n", + "[7, 50] loss: 0.00107\n", + "[7, 60] loss: 0.000166\n", + "[7, 70] loss: 0.000411\n", + "[7, 80] loss: 0.000847\n", + "[7, 90] loss: 0.000389\n", + "[7, 100] loss: 0.00119\n", + "[7, 110] loss: 0.00355\n", + "[7, 120] loss: 0.000582\n", + "[7, 130] loss: 0.000382\n", + "[7, 140] loss: 0.000217\n", + "[7, 150] loss: 2.56e-05\n", + "[7, 160] loss: 9.63e-07\n", + "[7, 170] loss: 0.000343\n", + "[7, 180] loss: 8.3e-07\n", + "[7, 190] loss: 0.0001\n", + "[7, 200] loss: 0.000439\n", + "[8, 10] loss: 0.000264\n", + "[8, 20] loss: 0.00025\n", + "[8, 30] loss: 0.000189\n", + "[8, 40] loss: 5.74e-05\n", + "[8, 50] loss: 2.1e-06\n", + "[8, 60] loss: 8.72e-08\n", + "[8, 70] loss: 4.47e-11\n", + "[8, 80] loss: 0.000159\n", + "[8, 90] loss: 1.49e-11\n", + "[8, 100] loss: 0.00045\n", + "[8, 110] loss: 5.63e-05\n", + "[8, 120] loss: 9.62e-06\n", + "[8, 130] loss: 0.000179\n", + "[8, 140] loss: 3.87e-10\n", + "[8, 150] loss: 0.000183\n", + "[8, 160] loss: 5.96e-11\n", + "[8, 170] loss: 4.68e-06\n", + "[8, 180] loss: 8.6e-05\n", + "[8, 190] loss: 1.34e-05\n", + "[8, 200] loss: 0.000313\n", + "[9, 10] loss: 3.13e-05\n", + "[9, 20] loss: 0.000285\n", + "[9, 30] loss: 0.000305\n", + "[9, 40] loss: 0.000151\n", + "[9, 50] loss: 9.28e-05\n", + "[9, 60] loss: 5.5e-05\n", + "[9, 70] loss: 0.000122\n", + "[9, 80] loss: 8.49e-06\n", + "[9, 90] loss: 3.62e-09\n", + "[9, 100] loss: 5.83e-07\n", + "[9, 110] loss: 1.54e-06\n", + "[9, 120] loss: 1.64e-10\n", + "[9, 130] loss: 0.00011\n", + "[9, 140] loss: 0\n", + "[9, 150] loss: 1.04e-10\n", + "[9, 160] loss: 0\n", + "[9, 170] loss: 1.8e-07\n", + "[9, 180] loss: 7.72e-07\n", + "[9, 190] loss: 4.52e-09\n", + "[9, 200] loss: 0.000218\n", + "[10, 10] loss: 2.37e-05\n", + "[10, 20] loss: 8.21e-08\n", + "[10, 30] loss: 1.83e-06\n", + "[10, 40] loss: 0\n", + "[10, 50] loss: 0\n", + "[10, 60] loss: 2.38e-10\n", + "[10, 70] loss: 6.94e-05\n", + "[10, 80] loss: 1.04e-07\n", + "[10, 90] loss: 0\n", + "[10, 100] loss: 1.56e-05\n", + "[10, 110] loss: 0.000305\n", + "[10, 120] loss: 3.43e-09\n", + "[10, 130] loss: 1.27e-06\n", + "[10, 140] loss: 0\n", + "[10, 150] loss: 4.68e-09\n", + "[10, 160] loss: 0\n", + "[10, 170] loss: 2.74e-08\n", + "[10, 180] loss: 3.66e-06\n", + "[10, 190] loss: 8.94e-11\n", + "[10, 200] loss: 0.000204\n", + "\n", + "Finished Training\n", + "4412 4417\n" + ] + }, + { + "data": { + "text/plain": [ + "99.88680099615124" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [] }, { @@ -1840,11 +2109,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": { "ExecuteTime": { - "end_time": "2020-10-13T06:00:52.240942Z", - "start_time": "2020-10-13T05:59:24.354Z" + "end_time": "2020-10-13T11:03:06.518361Z", + "start_time": "2020-10-13T11:03:06.511841Z" } }, "outputs": [], diff --git a/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.py b/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.py index fd7fdc0..8454dc5 100644 --- a/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.py +++ b/notebooks/c02_Intro_to_NN_Part_2/Intro_to_NN_Part_2.py @@ -120,13 +120,14 @@ # Credits to researchers at Georgia Tech, Agile Geoscience # License CCbySA # -# In this notebook, we will be using the landmass dataset, which have been preprocessed already. In this dataset, we have images of 4 different types of landmass: 'Chaotic Horizon', 'Fault', 'Horizon', 'Salt Dome'. +# In this notebook, we will be using the landmass dataset, which have been preprocessed already. In this dataset, we have images of 4 different types of landmass: ['Discontinuous', 'Faulted', 'Continuous', 'Salt']. # # This is an example of [seismic data](https://en.wikipedia.org/wiki/Reflection_seismology) which is a way of using seismic to image the structure of the Earth, below the surface. These waves are similar to sounds waves in air. The lines represent changes in density below the surface. # # We will train a CNN to learn how to classify images into those 4 groups. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device # + # Let's import the Patches @@ -157,6 +158,10 @@ print("Class:", landmassf3_train.classes[y]) x +# Note that this is an unbalanced dataset, so we expect an accuracy of at least 52%, this is out baseline +labels = pd.Series(landmassf3_train.train_labels).replace(dict(enumerate(landmassf3_train.classes))) +labels.value_counts() / len(landmassf3_train) + landmassf3_train.classes # Source: [Neural Networks](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html#sphx-glr-beginner-blitz-neural-networks-tutorial-py) @@ -400,8 +405,8 @@ def forward(self, x): convnet = BetterCNN().to(device) optimizer = torch.optim.Adam(convnet.parameters(), lr=learning_rate) -model = train(convnet, x_train, y_train, criterion, optimizer) -test(model, x_test, y_test) +convnet = train(convnet, x_train, y_train, criterion, optimizer) +test(convnet, x_test, y_test) from deep_ml_curriculum.torchsummaryX import summary # We can also summarise the number of parameters in each layer