-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0bbc119
commit 4c79f81
Showing
1 changed file
with
164 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "Semantic_Segmentation", | ||
"provenance": [], | ||
"authorship_tag": "ABX9TyMbR1iCnULkuJHPmF2JiyeM", | ||
"include_colab_link": true | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "view-in-github", | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"<a href=\"https://colab.research.google.com/github/harjeevanmaan/SemanticSegmentation/blob/master/Semantic_Segmentation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "X4DTWmvAPlB2", | ||
"colab_type": "code", | ||
"colab": {} | ||
}, | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"import numpy as np\n", | ||
"import tensorflow_datasets as tfds\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import logging\n", | ||
"import PIL.Image as Image\n", | ||
"!pip install git+https://github.com/tensorflow/examples.git\n", | ||
"!pip install -U tfds-nightly\n", | ||
"from tensorflow_examples.models.pix2pix import pix2pix\n", | ||
"from IPython.display import clear_output\n", | ||
"\n", | ||
"tfds.disable_progress_bar()\n", | ||
"logger = tf.get_logger()\n", | ||
"logger.setLevel(logging.ERROR)\n", | ||
"\n", | ||
"!mkdir -p /content/downloads/manual/cityscapes #need to store data here to make use of the tfds.load function\n", | ||
"!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=my_username&password=my_password&submit=Login' https://www.cityscapes-dataset.com/login/\n", | ||
"#need to use your own username and password to download the cityscapes dataset\n", | ||
"!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 -O /content/downloads/manual/cityscapes/leftImg8bit_trainvaltest.zip\n", | ||
"!wget --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=1 -O /content/downloads/manual/cityscapes/gtFine_trainvaltest.zip\n", | ||
"\n", | ||
"cscape, info = tfds.load('cityscapes', data_dir=\"/content\", with_info=True)\n", | ||
"\n", | ||
"IMAGE_RES = 128\n", | ||
"BATCH_SIZE = 8\n", | ||
"OUTPUT_CHANNELS = 35\n", | ||
"EPOCHS = 20\n", | ||
"VAL_SUBSPLITS = 5\n", | ||
"\n", | ||
"def normalize(input_image, input_mask):\n", | ||
" input_image = tf.cast(input_image, tf.float32) / 255.0\n", | ||
" return input_image, input_mask\n", | ||
"\n", | ||
"def load_image_train(datapoint):\n", | ||
" input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))\n", | ||
" input_mask = tf.image.resize(datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES))\n", | ||
" \n", | ||
" if tf.random.uniform(()) > 0.5:\n", | ||
" input_image = tf.image.flip_left_right(input_image)\n", | ||
" input_mask = tf.image.flip_left_right(input_mask)\n", | ||
"\n", | ||
" input_image, input_mask = normalize(input_image, input_mask)\n", | ||
" return input_image, input_mask\n", | ||
" \n", | ||
"def load_image_test(datapoint):\n", | ||
" input_image = tf.image.resize(datapoint['image_left'], (IMAGE_RES, IMAGE_RES))\n", | ||
" input_mask = tf.image.resize(datapoint['segmentation_label'], (IMAGE_RES, IMAGE_RES))\n", | ||
"\n", | ||
" input_image, input_mask = normalize(input_image, input_mask)\n", | ||
" return input_image, input_mask\n", | ||
"\n", | ||
"tr_num = info.splits['train'].num_examples\n", | ||
"val_num = info.splits['validation'].num_examples\n", | ||
"test_num = info.splits['test'].num_examples\n", | ||
"train = cscape['train'].map(load_image_train)\n", | ||
"\n", | ||
"tr_batches = cscape['train'].map(load_image_train).batch(BATCH_SIZE).shuffle(BATCH_SIZE).repeat().prefetch(1)\n", | ||
"test_batches = cscape['test'].map(load_image_test).batch(BATCH_SIZE).prefetch(1)\n", | ||
"\n", | ||
"STEPS_PER_EPOCH = tr_num//BATCH_SIZE\n", | ||
"\n", | ||
"base_model = tf.keras.applications.MobileNetV2(input_shape=[IMAGE_RES, IMAGE_RES, 3], include_top=False)\n", | ||
"\n", | ||
"layer_names = [\n", | ||
" 'block_1_expand_relu',\n", | ||
" 'block_3_expand_relu',\n", | ||
" 'block_6_expand_relu',\n", | ||
" 'block_13_expand_relu',\n", | ||
" 'block_16_expand_relu',\n", | ||
"]\n", | ||
"layers = [base_model.get_layer(name).output for name in layer_names]\n", | ||
"\n", | ||
"down_stack= tf.keras.Model(inputs=base_model.input, outputs=layers)\n", | ||
"down_stack.trainable = False\n", | ||
"\n", | ||
"up_stack = [\n", | ||
" pix2pix.upsample(512, 3),\n", | ||
" pix2pix.upsample(256, 3),\n", | ||
" pix2pix.upsample(128, 3),\n", | ||
" pix2pix.upsample(64, 3),\n", | ||
"]\n", | ||
"\n", | ||
"def unet_model(output_channels):\n", | ||
" inputs = tf.keras.layers.Input(shape=[IMAGE_RES, IMAGE_RES, 3])\n", | ||
" x = inputs\n", | ||
"\n", | ||
" skips = down_stack(x)\n", | ||
" x = skips[-1]\n", | ||
" skips = reversed(skips[:-1])\n", | ||
"\n", | ||
" for up, skip in zip(up_stack, skips):\n", | ||
" x = up(x)\n", | ||
" concat = tf.keras.layers.Concatenate()\n", | ||
" x = concat([x, skip])\n", | ||
"\n", | ||
" last = tf.keras.layers.Conv2DTranspose(\n", | ||
" output_channels, 3, strides=2,\n", | ||
" padding='same')\n", | ||
" x = last(x)\n", | ||
" return tf.keras.Model(inputs=inputs, outputs=x)\n", | ||
"\n", | ||
"model = unet_model(OUTPUT_CHANNELS)\n", | ||
"model.compile(optimizer='adam',\n", | ||
" loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", | ||
" metrics=['accuracy'])\n", | ||
"\n", | ||
"def create_mask(pred_mask):\n", | ||
" pred_mask = tf.argmax(pred_mask, axis=-1)\n", | ||
" pred_mask = pred_mask[..., tf.newaxis]\n", | ||
" return pred_mask[0]\n", | ||
"\n", | ||
"class DisplayCallback(tf.keras.callbacks.Callback):\n", | ||
" def on_epoch_end(self, epoch, logs=None):\n", | ||
" clear_output(wait=True)\n", | ||
" show_predictions()\n", | ||
" print('\\nSample Predictions after epoch {}\\n'.format(epoch+1))\n", | ||
"\n", | ||
"VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n", | ||
"\n", | ||
"model_history = model.fit(tr_batches, epochs=EPOCHS,\n", | ||
" steps_per_epoch = STEPS_PER_EPOCH,\n", | ||
" validation_steps = VALIDATION_STEPS,\n", | ||
" validation_data = test_batches,\n", | ||
" callbacks=[DisplayCallback()])" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |