Skip to content

Commit

Permalink
Created using Colaboratory
Browse files Browse the repository at this point in the history
  • Loading branch information
harjeevanmaan committed Jun 17, 2020
1 parent 0bbc119 commit 4c79f81
Showing 1 changed file with 164 additions and 0 deletions.
164 changes: 164 additions & 0 deletions Semantic_Segmentation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Semantic_Segmentation",
"provenance": [],
"authorship_tag": "ABX9TyMbR1iCnULkuJHPmF2JiyeM",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/harjeevanmaan/SemanticSegmentation/blob/master/Semantic_Segmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "X4DTWmvAPlB2",
"colab_type": "code",
"colab": {}
},
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import tensorflow_datasets as tfds\n",
"import matplotlib.pyplot as plt\n",
"import logging\n",
"import PIL.Image as Image\n",
"!pip install git+https://github.com/tensorflow/examples.git\n",
"!pip install -U tfds-nightly\n",
"from tensorflow_examples.models.pix2pix import pix2pix\n",
"from IPython.display import clear_output\n",
"\n",
"tfds.disable_progress_bar()\n",
"logger = tf.get_logger()\n",
"logger.setLevel(logging.ERROR)\n",
"\n",
"!mkdir -p /content/downloads/manual/cityscapes #need to store data here to make use of the tfds.load function\n",
"!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=my_username&password=my_password&submit=Login' https://www.cityscapes-dataset.com/login/\n",
"#need to use your own username and password to download the cityscapes dataset\n",
"!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -O /content/downloads/manual/cityscapes/leftImg8bit_trainvaltest.zip\n",
"!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -O /content/downloads/manual/cityscapes/gtFine_trainvaltest.zip\n",
"\n",
"cscape, info = tfds.load('cityscapes', data_dir=\"/content\", with_info=True)\n",
"\n",
"IMAGE_RES = 128\n",
"BATCH_SIZE = 8\n",
"OUTPUT_CHANNELS = 35\n",
"EPOCHS = 20\n",
"VAL_SUBSPLITS = 5\n",
"\n",
"def normalize(input_image, input_mask):\n",
" input_image = tf.cast(input_image, tf.float32) / 255.0\n",
" return input_image, input_mask\n",
"\n",
"def load_image_train(datapoint):\n",
" input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))\n",
" input_mask = tf.image.resize(datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES))\n",
" \n",
" if tf.random.uniform(()) > 0.5:\n",
" input_image = tf.image.flip_left_right(input_image)\n",
" input_mask = tf.image.flip_left_right(input_mask)\n",
"\n",
" input_image, input_mask = normalize(input_image, input_mask)\n",
" return input_image, input_mask\n",
" \n",
"def load_image_test(datapoint):\n",
" input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))\n",
" input_mask = tf.image.resize(datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES))\n",
"\n",
" input_image, input_mask = normalize(input_image, input_mask)\n",
" return input_image, input_mask\n",
"\n",
"tr_num = info.splits['train'].num_examples\n",
"val_num = info.splits['validation'].num_examples\n",
"test_num = info.splits['test'].num_examples\n",
"train = cscape['train'].map(load_image_train)\n",
"\n",
"tr_batches = cscape['train'].map(load_image_train).batch(BATCH_SIZE).shuffle(BATCH_SIZE).repeat().prefetch(1)\n",
"test_batches = cscape['test'].map(load_image_test).batch(BATCH_SIZE).prefetch(1)\n",
"\n",
"STEPS_PER_EPOCH = tr_num//BATCH_SIZE\n",
"\n",
"base_model = tf.keras.applications.MobileNetV2(input_shape=[IMAGE_RES, IMAGE_RES, 3], include_top=False)\n",
"\n",
"layer_names = [\n",
" 'block_1_expand_relu',\n",
" 'block_3_expand_relu',\n",
" 'block_6_expand_relu',\n",
" 'block_13_expand_relu',\n",
" 'block_16_expand_relu',\n",
"]\n",
"layers = [base_model.get_layer(name).output for name in layer_names]\n",
"\n",
"down_stack= tf.keras.Model(inputs=base_model.input, outputs=layers)\n",
"down_stack.trainable = False\n",
"\n",
"up_stack = [\n",
" pix2pix.upsample(512, 3),\n",
" pix2pix.upsample(256, 3),\n",
" pix2pix.upsample(128, 3),\n",
" pix2pix.upsample(64, 3),\n",
"]\n",
"\n",
"def unet_model(output_channels):\n",
" inputs = tf.keras.layers.Input(shape=[IMAGE_RES, IMAGE_RES, 3])\n",
" x = inputs\n",
"\n",
" skips = down_stack(x)\n",
" x = skips[-1]\n",
" skips = reversed(skips[:-1])\n",
"\n",
" for up, skip in zip(up_stack, skips):\n",
" x = up(x)\n",
" concat = tf.keras.layers.Concatenate()\n",
" x = concat([x, skip])\n",
"\n",
" last = tf.keras.layers.Conv2DTranspose(\n",
" output_channels, 3, strides=2,\n",
" padding='same')\n",
" x = last(x)\n",
" return tf.keras.Model(inputs=inputs, outputs=x)\n",
"\n",
"model = unet_model(OUTPUT_CHANNELS)\n",
"model.compile(optimizer='adam',\n",
" loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
" metrics=['accuracy'])\n",
"\n",
"def create_mask(pred_mask):\n",
" pred_mask = tf.argmax(pred_mask, axis=-1)\n",
" pred_mask = pred_mask[..., tf.newaxis]\n",
" return pred_mask[0]\n",
"\n",
"class DisplayCallback(tf.keras.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" clear_output(wait=True)\n",
" show_predictions()\n",
" print('\\nSample Predictions after epoch {}\\n'.format(epoch+1))\n",
"\n",
"VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n",
"\n",
"model_history = model.fit(tr_batches, epochs=EPOCHS,\n",
" steps_per_epoch = STEPS_PER_EPOCH,\n",
" validation_steps = VALIDATION_STEPS,\n",
" validation_data = test_batches,\n",
" callbacks=[DisplayCallback()])"
],
"execution_count": null,
"outputs": []
}
]
}

0 comments on commit 4c79f81

Please sign in to comment.