diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..744b649 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/logs/ +/0try/ +/_pycache_/ +/Semantic_Segmentation_Dataset/ + diff --git a/README.md b/README.md index 39af52c..6abc0bd 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,20 @@ -# README # +Instructions: -This README would normally document whatever steps are necessary to get your application up and running. +To train the model + +```python train.py --help``` +eg. -### What is this repository for? ### +```python train.py --model densenet --expname FINAL --bs 4 --useGPU True --dataset Semantic_Segmentation_Dataset/``` -* Quick summary -* Version -* [Learn Markdown](https://bitbucket.org/tutorials/markdowndemo) -### How do I get set up? ### +To test the result: + +```python test.py --model densenet --load best_model.pkl --bs 4 --dataset Semantic_Segmentation_Dataset/``` -* Summary of set up -* Configuration -* Dependencies -* Database configuration -* How to run tests -* Deployment instructions -### Contribution guidelines ### -* Writing tests -* Code review -* Other guidelines +The requirements.txt file contains all the packages necessary for the code to run. We have also included an environment.yml file of the system which runs the code successfully. Please refer to that file if there is an error with specific packages. + -### Who do I talk to? ### -* Repo owner or admin -* Other community or team contact \ No newline at end of file diff --git a/Starburst generation from train image 000000240768.pdf b/Starburst generation from train image 000000240768.pdf new file mode 100644 index 0000000..4333918 Binary files /dev/null and b/Starburst generation from train image 000000240768.pdf differ diff --git a/__pycache__/dataset.cpython-36.pyc b/__pycache__/dataset.cpython-36.pyc new file mode 100644 index 0000000..de46c6e Binary files /dev/null and b/__pycache__/dataset.cpython-36.pyc differ diff --git a/__pycache__/densenet.cpython-36.pyc b/__pycache__/densenet.cpython-36.pyc new file mode 100644 index 0000000..50270d1 Binary files /dev/null and b/__pycache__/densenet.cpython-36.pyc differ diff --git a/__pycache__/models.cpython-36.pyc b/__pycache__/models.cpython-36.pyc new file mode 100644 index 0000000..57b1794 Binary files /dev/null and b/__pycache__/models.cpython-36.pyc differ diff --git a/__pycache__/opt.cpython-36.pyc b/__pycache__/opt.cpython-36.pyc new file mode 100644 index 0000000..7749da3 Binary files /dev/null and b/__pycache__/opt.cpython-36.pyc differ diff --git a/__pycache__/utils.cpython-36.pyc b/__pycache__/utils.cpython-36.pyc new file mode 100644 index 0000000..e7f5598 Binary files /dev/null and b/__pycache__/utils.cpython-36.pyc differ diff --git a/best_model.pkl b/best_model.pkl new file mode 100644 index 0000000..f0864e6 Binary files /dev/null and b/best_model.pkl differ diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..1dac5f7 --- /dev/null +++ b/dataset.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 2 11:47:44 2019 + +@author: Aayush + +This file contains the dataloader and the augmentations and preprocessing done + +Required Preprocessing for all images (test, train and validation set): +1) Gamma correction by a factor of 0.8 +2) local Contrast limited adaptive histogram equalization algorithm with clipLimit=1.5, tileGridSize=(8,8) +3) Normalization + +Train Image Augmentation Procedure Followed +1) Random horizontal flip with 50% probability. +2) Starburst pattern augmentation with 20% probability. +3) Random length lines augmentation around a random center with 20% probability. +4) Gaussian blur with kernel size (7,7) and random sigma with 20% probability. +5) Translation of image and labels in any direction with random factor less than 20. +""" + +import numpy as np +import torch +from torch.utils.data import Dataset +import os +from PIL import Image +from torchvision import transforms +import cv2 +import random +import os.path as osp +from utils import one_hot2dist +import copy + +transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize([0.5], [0.5])]) + +#%% +class RandomHorizontalFlip(object): + def __call__(self, img,label): + if random.random() < 0.5: + return img.transpose(Image.FLIP_LEFT_RIGHT),\ + label.transpose(Image.FLIP_LEFT_RIGHT) + return img,label + +class Starburst_augment(object): + ## We have generated the starburst pattern from a train image 000000240768.png + ## Please follow the file Starburst_generation_from_train_image_000000240768.pdf attached in the folder + ## This procedure is used in order to handle people with multiple reflections for glasses + ## a random translation of mask of starburst pattern + def __call__(self, img): + x=np.random.randint(1, 40) + y=np.random.randint(1, 40) + mode = np.random.randint(0, 2) + starburst=Image.open('starburst_black.png').convert("L") + if mode == 0: + starburst = np.pad(starburst, pad_width=((0, 0), (x, 0)), mode='constant') + starburst = starburst[:, :-x] + if mode == 1: + starburst = np.pad(starburst, pad_width=((0, 0), (0, x)), mode='constant') + starburst = starburst[:, x:] + + img[92+y:549+y,0:400]=np.array(img)[92+y:549+y,0:400]*((255-np.array(starburst))/255)+np.array(starburst) + return Image.fromarray(img) + +def getRandomLine(xc, yc, theta): + x1 = xc - 50*np.random.rand(1)*(1 if np.random.rand(1) < 0.5 else -1) + y1 = (x1 - xc)*np.tan(theta) + yc + x2 = xc - (150*np.random.rand(1) + 50)*(1 if np.random.rand(1) < 0.5 else -1) + y2 = (x2 - xc)*np.tan(theta) + yc + return x1, y1, x2, y2 + +class Gaussian_blur(object): + def __call__(self, img): + sigma_value=np.random.randint(2, 7) + return Image.fromarray(cv2.GaussianBlur(img,(7,7),sigma_value)) + +class Translation(object): + def __call__(self, base,mask): + factor_h = 2*np.random.randint(1, 20) + factor_v = 2*np.random.randint(1, 20) + mode = np.random.randint(0, 4) +# print (mode,factor_h,factor_v) + if mode == 0: + aug_base = np.pad(base, pad_width=((factor_v, 0), (0, 0)), mode='constant') + aug_mask = np.pad(mask, pad_width=((factor_v, 0), (0, 0)), mode='constant') + aug_base = aug_base[:-factor_v, :] + aug_mask = aug_mask[:-factor_v, :] + if mode == 1: + aug_base = np.pad(base, pad_width=((0, factor_v), (0, 0)), mode='constant') + aug_mask = np.pad(mask, pad_width=((0, factor_v), (0, 0)), mode='constant') + aug_base = aug_base[factor_v:, :] + aug_mask = aug_mask[factor_v:, :] + if mode == 2: + aug_base = np.pad(base, pad_width=((0, 0), (factor_h, 0)), mode='constant') + aug_mask = np.pad(mask, pad_width=((0, 0), (factor_h, 0)), mode='constant') + aug_base = aug_base[:, :-factor_h] + aug_mask = aug_mask[:, :-factor_h] + if mode == 3: + aug_base = np.pad(base, pad_width=((0, 0), (0, factor_h)), mode='constant') + aug_mask = np.pad(mask, pad_width=((0, 0), (0, factor_h)), mode='constant') + aug_base = aug_base[:, factor_h:] + aug_mask = aug_mask[:, factor_h:] + return Image.fromarray(aug_base), Image.fromarray(aug_mask) + +class Line_augment(object): + def __call__(self, base): + yc, xc = (0.3 + 0.4*np.random.rand(1))*base.shape + aug_base = copy.deepcopy(base) + num_lines = np.random.randint(1, 10) + for i in np.arange(0, num_lines): + theta = np.pi*np.random.rand(1) + x1, y1, x2, y2 = getRandomLine(xc, yc, theta) + aug_base = cv2.line(aug_base, (x1, y1), (x2, y2), (255, 255, 255), 4) + aug_base = aug_base.astype(np.uint8) + return Image.fromarray(aug_base) + +class MaskToTensor(object): + def __call__(self, img): + return torch.from_numpy(np.array(img, dtype=np.int32)).long() + + +class IrisDataset(Dataset): + def __init__(self, filepath, split='train',transform=None,**args): + self.transform = transform + self.filepath= osp.join(filepath,split) + self.split = split + listall = [] + + for file in os.listdir(osp.join(self.filepath,'images')): + if file.endswith(".png"): + listall.append(file.strip(".png")) + self.list_files=listall + + self.testrun = args.get('testrun') + + #PREPROCESSING STEP FOR ALL TRAIN, VALIDATION AND TEST INPUTS + #local Contrast limited adaptive histogram equalization algorithm + self.clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8,8)) + + def __len__(self): + if self.testrun: + return 10 + return len(self.list_files) + + def __getitem__(self, idx): + imagepath = osp.join(self.filepath,'images',self.list_files[idx]+'.png') + pilimg = Image.open(imagepath).convert("L") + H, W = pilimg.width , pilimg.height + + #PREPROCESSING STEP FOR ALL TRAIN, VALIDATION AND TEST INPUTS + #Fixed gamma value for + table = 255.0*(np.linspace(0, 1, 256)**0.8) + pilimg = cv2.LUT(np.array(pilimg), table) + + + if self.split != 'test': + labelpath = osp.join(self.filepath,'labels',self.list_files[idx]+'.npy') + label = np.load(labelpath) + label = np.resize(label,(W,H)) + label = Image.fromarray(label) + + if self.transform is not None: + if self.split == 'train': + if random.random() < 0.2: + pilimg = Starburst_augment()(np.array(pilimg)) + if random.random() < 0.2: + pilimg = Line_augment()(np.array(pilimg)) + if random.random() < 0.2: + pilimg = Gaussian_blur()(np.array(pilimg)) + if random.random() < 0.4: + pilimg, label = Translation()(np.array(pilimg),np.array(label)) + + img = self.clahe.apply(np.array(np.uint8(pilimg))) + img = Image.fromarray(img) + + if self.transform is not None: + if self.split == 'train': + img, label = RandomHorizontalFlip()(img,label) + img = self.transform(img) + + + if self.split != 'test': + ## This is for boundary aware cross entropy calculation + spatialWeights = cv2.Canny(np.array(label),0,3)/255 + spatialWeights=cv2.dilate(spatialWeights,(3,3),iterations = 1)*20 + + ##This is the implementation for the surface loss + # Distance map for each class + distMap = [] + for i in range(0, 4): + distMap.append(one_hot2dist(np.array(label)==i)) + distMap = np.stack(distMap, 0) +# spatialWeights=np.float32(distMap) + + + if self.split == 'test': + ##since label, spatialWeights and distMap is not needed for test images + return img,0,self.list_files[idx],0,0 + + label = MaskToTensor()(label) + return img, label, self.list_files[idx],spatialWeights,np.float32(distMap) + +if __name__ == "__main__": + import matplotlib.pyplot as plt + ds = IrisDataset('Semantic_Segmentation_Dataset',split='train',transform=transform) +# for i in range(1000): + img, label, idx,x,y= ds[0] + plt.subplot(121) + plt.imshow(np.array(label)) + plt.subplot(122) + plt.imshow(np.array(img)[0,:,:],cmap='gray') \ No newline at end of file diff --git a/densenet.py b/densenet.py new file mode 100644 index 0000000..dbe5160 --- /dev/null +++ b/densenet.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 2 11:20:33 2019 + +@author: Shusil Dangi (Adopted from Shusil Dangi) +""" +import torch +import math +import torch.nn as nn +import torch.nn.functional as F + +class DenseNet2D_down_block(nn.Module): + def __init__(self,input_channels,output_channels,down_size,dropout=False,prob=0): + super(DenseNet2D_down_block, self).__init__() + self.conv1 = nn.Conv2d(input_channels,output_channels,kernel_size=(3,3),padding=(1,1)) + self.conv21 = nn.Conv2d(input_channels+output_channels,output_channels,kernel_size=(1,1),padding=(0,0)) + self.conv22 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1)) + self.conv31 = nn.Conv2d(input_channels+2*output_channels,output_channels,kernel_size=(1,1),padding=(0,0)) + self.conv32 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1)) + self.max_pool = nn.AvgPool2d(kernel_size=down_size) + + self.relu = nn.LeakyReLU() + self.down_size = down_size + self.dropout = dropout + self.dropout1 = nn.Dropout(p=prob) + self.dropout2 = nn.Dropout(p=prob) + self.dropout3 = nn.Dropout(p=prob) + self.bn = torch.nn.BatchNorm2d(num_features=output_channels) + + def forward(self, x): + if self.down_size != None: + x = self.max_pool(x) + + if self.dropout: + x1 = self.relu(self.dropout1(self.conv1(x))) + x21 = torch.cat((x,x1),dim=1) + x22 = self.relu(self.dropout2(self.conv22(self.conv21(x21)))) + x31 = torch.cat((x21,x22),dim=1) + out = self.relu(self.dropout3(self.conv32(self.conv31(x31)))) + else: + x1 = self.relu(self.conv1(x)) + x21 = torch.cat((x,x1),dim=1) + x22 = self.relu(self.conv22(self.conv21(x21))) + x31 = torch.cat((x21,x22),dim=1) + out = self.relu(self.conv32(self.conv31(x31))) + return self.bn(out) + + +class DenseNet2D_up_block_concat(nn.Module): + def __init__(self,skip_channels,input_channels,output_channels,up_stride,dropout=False,prob=0): + super(DenseNet2D_up_block_concat, self).__init__() + self.conv11 = nn.Conv2d(skip_channels+input_channels,output_channels,kernel_size=(1,1),padding=(0,0)) + self.conv12 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1)) + self.conv21 = nn.Conv2d(skip_channels+input_channels+output_channels,output_channels, + kernel_size=(1,1),padding=(0,0)) + self.conv22 = nn.Conv2d(output_channels,output_channels,kernel_size=(3,3),padding=(1,1)) + self.relu = nn.LeakyReLU() + self.up_stride = up_stride + self.dropout = dropout + self.dropout1 = nn.Dropout(p=prob) + self.dropout2 = nn.Dropout(p=prob) + + def forward(self,prev_feature_map,x): + x = nn.functional.interpolate(x,scale_factor=self.up_stride,mode='nearest') + x = torch.cat((x,prev_feature_map),dim=1) + if self.dropout: + x1 = self.relu(self.dropout1(self.conv12(self.conv11(x)))) + x21 = torch.cat((x,x1),dim=1) + out = self.relu(self.dropout2(self.conv22(self.conv21(x21)))) + else: + x1 = self.relu(self.conv12(self.conv11(x))) + x21 = torch.cat((x,x1),dim=1) + out = self.relu(self.conv22(self.conv21(x21))) + return out + +class DenseNet2D(nn.Module): + def __init__(self,in_channels=1,out_channels=4,channel_size=32,concat=True,dropout=False,prob=0): + super(DenseNet2D, self).__init__() + + self.down_block1 = DenseNet2D_down_block(input_channels=in_channels,output_channels=channel_size, + down_size=None,dropout=dropout,prob=prob) + self.down_block2 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, + down_size=(2,2),dropout=dropout,prob=prob) + self.down_block3 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, + down_size=(2,2),dropout=dropout,prob=prob) + self.down_block4 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, + down_size=(2,2),dropout=dropout,prob=prob) + self.down_block5 = DenseNet2D_down_block(input_channels=channel_size,output_channels=channel_size, + down_size=(2,2),dropout=dropout,prob=prob) + + self.up_block1 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, + output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) + self.up_block2 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, + output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) + self.up_block3 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, + output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) + self.up_block4 = DenseNet2D_up_block_concat(skip_channels=channel_size,input_channels=channel_size, + output_channels=channel_size,up_stride=(2,2),dropout=dropout,prob=prob) + + self.out_conv1 = nn.Conv2d(in_channels=channel_size,out_channels=out_channels,kernel_size=1,padding=0) + self.concat = concat + self.dropout = dropout + self.dropout1 = nn.Dropout(p=prob) + + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + n = m.weight.size(1) + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + def forward(self,x): + self.x1 = self.down_block1(x) + self.x2 = self.down_block2(self.x1) + self.x3 = self.down_block3(self.x2) + self.x4 = self.down_block4(self.x3) + self.x5 = self.down_block5(self.x4) + self.x6 = self.up_block1(self.x4,self.x5) + self.x7 = self.up_block2(self.x3,self.x6) + self.x8 = self.up_block3(self.x2,self.x7) + self.x9 = self.up_block4(self.x1,self.x8) + if self.dropout: + out = self.out_conv1(self.dropout1(self.x9)) + else: + out = self.out_conv1(self.x9) + + return out + diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..e5306f4 --- /dev/null +++ b/environment.yml @@ -0,0 +1,311 @@ +name: base +channels: + - anaconda + - menpo + - pytorch + - conda-forge + - defaults +dependencies: + - _ipyw_jlab_nb_ext_conf=0.1.0=py36he11e457_0 + - _libgcc_mutex=0.1=main + - alabaster=0.7.10=py36h306e16b_0 + - anaconda-client=1.6.14=py36_0 + - anaconda-navigator=1.8.7=py36_0 + - anaconda-project=0.8.2=py36h44fb852_0 + - asn1crypto=0.24.0=py36_0 + - astroid=1.6.3=py36_0 + - astropy=3.0.2=py36h3010b51_1 + - attrs=18.1.0=py36_0 + - av=6.0.0=py36h7273d18_0 + - babel=2.5.3=py36_0 + - backcall=0.1.0=py36_0 + - backports=1.0=py36hfa02d7e_1 + - backports.shutil_get_terminal_size=1.0.0=py36hfea85ff_2 + - beautifulsoup4=4.6.0=py36h49b8c8c_1 + - bitarray=0.8.1=py36h14c3975_1 + - bkcharts=0.2=py36h735825a_0 + - blas=1.0=mkl + - blaze=0.11.3=py36h4e06776_0 + - bleach=2.1.3=py36_0 + - blosc=1.14.3=hdbcaa40_0 + - bokeh=0.12.16=py36_0 + - boto=2.48.0=py36h6e4cd66_1 + - bottleneck=1.2.1=py36haac1ea0_0 + - bzip2=1.0.6=h14c3975_5 + - ca-certificates=2019.5.15=0 + - cairo=1.14.12=h7636065_2 + - certifi=2019.6.16=py36_1 + - cffi=1.12.3=py36h2e261b9_0 + - chardet=3.0.4=py36h0f667ec_1 + - click=6.7=py36h5253387_0 + - cloudpickle=0.5.3=py36_0 + - clyent=1.2.2=py36h7e57e65_1 + - colorama=0.3.9=py36h489cec4_0 + - conda=4.6.11=py36_0 + - conda-build=3.10.9=py36_0 + - conda-env=2.6.0=h36134e3_1 + - conda-verify=2.0.0=py36h98955d8_0 + - contextlib2=0.5.5=py36h6c84a62_0 + - cryptography=2.2.2=py36h14c3975_0 + - cuda75=1.0=hf2493ae_0 + - cudatoolkit=8.0=3 + - curl=7.60.0=h84994c4_0 + - cycler=0.10.0=py36h93f1223_0 + - cython=0.28.2=py36h14c3975_0 + - cytoolz=0.9.0.1=py36h14c3975_0 + - dask=0.17.5=py36_0 + - dask-core=0.17.5=py36_0 + - datashape=0.5.4=py36h3ad6b5c_0 + - dbus=1.13.2=h714fa37_1 + - decorator=4.3.0=py36_0 + - distributed=1.21.8=py36_0 + - docutils=0.14=py36hb0f60f5_0 + - entrypoints=0.2.3=py36h1aec115_2 + - et_xmlfile=1.0.1=py36hd6bccc3_0 + - expat=2.2.5=he0dffb1_0 + - fastcache=1.0.2=py36h14c3975_2 + - ffmpeg=4.0.2=ha6a6e2b_0 + - filelock=3.0.4=py36_0 + - flask=1.0.2=py36_1 + - flask-cors=3.0.4=py36_0 + - fontconfig=2.12.6=h49f89f6_0 + - freeglut=2.8.1=0 + - freetype=2.8.1=hfa320df_1 + - get_terminal_size=1.0.0=haa9412d_0 + - gevent=1.3.0=py36h14c3975_0 + - glib=2.56.1=h000015b_0 + - glob2=0.6=py36he249c77_0 + - gmp=6.1.2=h6c8ec71_1 + - gmpy2=2.0.8=py36hc8893dd_2 + - gnutls=3.5.19=h2a4e5f8_1 + - graphite2=1.3.11=h16798f4_2 + - greenlet=0.4.13=py36h14c3975_0 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - h5py=2.7.1=py36ha1f6525_2 + - harfbuzz=1.7.6=h5f0a787_1 + - hdf5=1.10.2=hba1933b_1 + - heapdict=1.0.0=py36_2 + - html5lib=1.0.1=py36h2f9c1c0_0 + - icu=58.2=h9c2bf20_1 + - idna=2.6=py36h82fb2a8_1 + - imageio=2.3.0=py36_0 + - imagesize=1.0.0=py36_0 + - intel-openmp=2018.0.0=8 + - ipykernel=4.8.2=py36_0 + - ipython=6.4.0=py36_0 + - ipython_genutils=0.2.0=py36hb52b0d5_0 + - ipywidgets=7.2.1=py36_0 + - isort=4.3.4=py36_0 + - itsdangerous=0.24=py36h93cc618_1 + - jasper=1.900.1=hd497a04_4 + - jbig=2.1=hdba287a_0 + - jdcal=1.4=py36_0 + - jedi=0.12.0=py36_1 + - jinja2=2.10=py36ha16c418_0 + - jpeg=9b=h024ee3a_2 + - jsonschema=2.6.0=py36h006f8b5_0 + - jupyter=1.0.0=py36_4 + - jupyter_client=5.2.3=py36_0 + - jupyter_console=5.2.0=py36he59e554_1 + - jupyter_core=4.4.0=py36h7c827e3_0 + - jupyterlab=0.32.1=py36_0 + - jupyterlab_launcher=0.10.5=py36_0 + - kiwisolver=1.0.1=py36h764f252_0 + - lazy-object-proxy=1.3.1=py36h10fcdad_0 + - libcurl=7.60.0=h1ad7b7a_0 + - libedit=3.1.20181209=hc058e9b_0 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=9.1.0=hdf63c60_0 + - libgfortran-ng=7.2.0=hdf63c60_3 + - libiconv=1.15=h470a237_3 + - libopencv=3.4.1=h1a3b859_1 + - libopus=1.2.1=hb9ed12e_0 + - libpng=1.6.34=hb9fc6fc_0 + - libprotobuf=3.5.2=h6f1eeef_0 + - libsodium=1.0.16=h1bed415_0 + - libssh2=1.8.0=h9cfc8f7_4 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.0.9=he85c1e1_1 + - libtool=2.4.6=h544aabb_3 + - libvpx=1.7.0=h439df22_0 + - libxcb=1.13=h1bed415_1 + - libxml2=2.9.8=h26e45fe_1 + - libxslt=1.1.32=h1312cb7_0 + - llvmlite=0.23.1=py36hdbcaa40_0 + - locket=0.2.0=py36h787c0ad_1 + - lxml=4.2.1=py36h23eabaa_0 + - lzo=2.10=h49e0be7_2 + - markupsafe=1.0=py36hd9260cd_1 + - matplotlib=2.2.2=py36h0e671d2_1 + - mccabe=0.6.1=py36h5ad9710_1 + - mistune=0.8.3=py36h14c3975_1 + - mkl=2018.0.3=1 + - mkl-service=1.1.2=py36h17a0993_4 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - more-itertools=4.1.0=py36_0 + - mpc=1.0.3=hec55b23_5 + - mpfr=3.1.5=h11a74b3_2 + - mpmath=1.0.0=py36hfeacd6b_2 + - msgpack-python=0.5.6=py36h6bb024c_0 + - multipledispatch=0.5.0=py36_0 + - navigator-updater=0.2.1=py36_0 + - nbconvert=5.3.1=py36hb41ffb7_0 + - nbformat=4.4.0=py36h31c9010_0 + - ncurses=6.1=he6710b0_1 + - nettle=3.3=0 + - networkx=2.1=py36_0 + - ninja=1.7.2=0 + - nltk=3.3.0=py36_0 + - nose=1.3.7=py36hcdf7029_2 + - notebook=5.5.0=py36_0 + - numba=0.38.0=py36h637b7d7_0 + - numexpr=2.6.5=py36h7bf3b9c_0 + - numpy=1.15.4=py36h1d66e8a_0 + - numpy-base=1.15.4=py36h81de0dd_0 + - numpydoc=0.8.0=py36_0 + - odo=0.5.1=py36h90ed295_0 + - olefile=0.45.1=py36_0 + - opencv3=3.1.0=py36_0 + - openh264=1.7.0=0 + - openpyxl=2.5.3=py36_0 + - openssl=1.1.1c=h7b6447c_1 + - packaging=17.1=py36_0 + - pandas=0.23.0=py36h637b7d7_0 + - pandoc=1.19.2.1=hea2e7c5_1 + - pandocfilters=1.4.2=py36ha6701b7_1 + - pango=1.41.0=hd475d92_0 + - parso=0.2.0=py36_0 + - partd=0.3.8=py36h36fd896_0 + - patchelf=0.9=hf79760b_2 + - path.py=11.0.1=py36_0 + - pathlib2=2.3.2=py36_0 + - patsy=0.5.0=py36_0 + - pcre=8.42=h439df22_0 + - pep8=1.7.1=py36_0 + - pexpect=4.5.0=py36_0 + - pickleshare=0.7.4=py36h63277f8_0 + - pillow=5.1.0=py36h3deb7b8_0 + - pixman=0.34.0=hceecf20_3 + - pkginfo=1.4.2=py36_1 + - pluggy=0.6.0=py36hb689045_0 + - ply=3.11=py36_0 + - prompt_toolkit=1.0.15=py36h17d85b1_0 + - psutil=5.4.5=py36h14c3975_0 + - ptyprocess=0.5.2=py36h69acd42_0 + - py=1.5.3=py36_0 + - py-opencv=3.4.1=py36h0676e08_1 + - pycodestyle=2.4.0=py36_0 + - pycosat=0.6.3=py36h0a5515d_0 + - pycparser=2.19=py36_0 + - pycrypto=2.6.1=py36h14c3975_8 + - pycurl=7.43.0.1=py36hb7f436b_0 + - pyflakes=1.6.0=py36h7bd6a15_0 + - pygments=2.2.0=py36h0d3125c_0 + - pylint=1.8.4=py36_0 + - pyodbc=4.0.23=py36hf484d3e_0 + - pyopengl=3.1.1a1=py36_0 + - pyopenssl=18.0.0=py36_0 + - pyparsing=2.2.0=py36hee85983_1 + - pyqt=5.9.2=py36h751905a_0 + - pyserial=3.4=py36_0 + - pysocks=1.6.8=py36_0 + - pytables=3.4.3=py36h02b9ad4_2 + - pytest=3.5.1=py36_0 + - pytest-arraydiff=0.2=py36_0 + - pytest-astropy=0.3.0=py36_0 + - pytest-doctestplus=0.1.3=py36_0 + - pytest-openfiles=0.3.0=py36_0 + - pytest-remotedata=0.2.1=py36_0 + - python=3.6.9=h265db76_0 + - python-dateutil=2.7.3=py36_0 + - pytz=2018.4=py36_0 + - pywavelets=0.5.2=py36he602eb0_0 + - pyyaml=3.12=py36hafb9ca4_1 + - pyzmq=17.0.0=py36h14c3975_0 + - qt=5.9.5=h7e424d6_0 + - qtawesome=0.4.4=py36h609ed8c_0 + - qtconsole=4.3.1=py36h8f73b5b_0 + - qtpy=1.4.1=py36_0 + - readline=7.0=h7b6447c_5 + - requests=2.18.4=py36he2e5f8d_1 + - rope=0.10.7=py36h147e2ec_0 + - ruamel_yaml=0.15.35=py36h14c3975_1 + - scikit-image=0.13.1=py36h14c3975_1 + - scikit-learn=0.19.1=py36h7aa7ec6_0 + - scipy=1.1.0=py36hfc37229_0 + - seaborn=0.8.1=py36hfad7ec4_0 + - send2trash=1.5.0=py36_0 + - setuptools=41.0.1=py36_0 + - simplegeneric=0.8.1=py36_2 + - singledispatch=3.4.0.3=py36h7a266c3_0 + - sip=4.19.8=py36hf484d3e_0 + - six=1.11.0=py36h372c433_1 + - snappy=1.1.7=hbae5bb6_3 + - snowballstemmer=1.2.1=py36h6febd40_0 + - sortedcollections=0.6.1=py36_0 + - sortedcontainers=1.5.10=py36_0 + - sphinx=1.7.4=py36_0 + - sphinxcontrib=1.0=py36h6d0f590_1 + - sphinxcontrib-websupport=1.0.1=py36hb5cb234_1 + - spyder=3.2.8=py36_0 + - sqlalchemy=1.2.7=py36h6b74fdf_0 + - sqlite=3.29.0=h7b6447c_0 + - statsmodels=0.9.0=py36h3010b51_0 + - sympy=1.1.1=py36hc6d1c1c_0 + - tblib=1.3.2=py36h34cf8b6_0 + - terminado=0.8.1=py36_1 + - testpath=0.3.1=py36h8cadb63_0 + - tk=8.6.8=hbc83047_0 + - toolz=0.9.0=py36_0 + - tornado=5.0.2=py36_0 + - traitlets=4.3.2=py36h674d592_0 + - typing=3.6.4=py36_0 + - unicodecsv=0.14.1=py36ha668878_0 + - unixodbc=2.3.6=h1bed415_0 + - urllib3=1.22=py36hbe7ace6_0 + - wcwidth=0.1.7=py36hdf4376a_0 + - webencodings=0.5.1=py36h800622e_1 + - werkzeug=0.14.1=py36_0 + - wheel=0.33.4=py36_0 + - widgetsnbextension=3.2.1=py36_0 + - wrapt=1.10.11=py36h28b7045_0 + - x264=1!152.20180717=h470a237_1 + - xlrd=1.1.0=py36h1db9f0c_1 + - xlsxwriter=1.0.4=py36_0 + - xlwt=1.3.0=py36h7b00a1f_0 + - xz=5.2.4=h14c3975_4 + - yaml=0.1.7=had09818_2 + - zeromq=4.2.5=h439df22_0 + - zict=0.1.3=py36h3a3bf81_0 + - zlib=1.2.11=h7b6447c_3 + - pip: + - deepdish==0.3.6 + - dlib==19.16.0 + - enum34==1.1.6 + - ffprobe==0.5 + - future==0.17.1 + - imutils==0.5.1 + - iso8601==0.1.12 + - jupyter-http-over-ws==0.0.6 + - open-3d==0.3.0.0 + - open3d-official==0.3.0.0 + - pims==0.4.1 + - pip==19.0.3 + - pkg-config==0.0.1 + - plyfile==0.7 + - poppy==0.8.0 + - pptk==0.1.0 + - py-tvd==1.0 + - pybind11==2.2.4 + - python-pptx==0.6.17 + - scikit-video==1.1.10 + - simpleitk==1.2.2 + - slicerator==0.9.8 + - torch==1.0.1 + - torchsummary==1.5.1 + - torchvision==0.4.0 + - tqdm==4.35.0 +prefix: /home/aaa/anaconda3 + diff --git a/models.py b/models.py new file mode 100644 index 0000000..17964d7 --- /dev/null +++ b/models.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Sun Sep 8 18:50:11 2019 + +@author: manoj +""" + + +from densenet import DenseNet2D +model_dict = {} + +model_dict['densenet'] = DenseNet2D(dropout=True,prob=0.2) diff --git a/opt.py b/opt.py new file mode 100644 index 0000000..6c8de2c --- /dev/null +++ b/opt.py @@ -0,0 +1,40 @@ +from pprint import pprint +import argparse + +def parse_args(): + + parser = argparse.ArgumentParser() + # Data input settings + parser.add_argument('--dataset', type=str, default='Semantic_Segmentation_Dataset/', help='name of dataset') + # Optimization: General + parser.add_argument('--bs', type=int, default = 8 ) + parser.add_argument('--epochs', type=int,help='Number of epochs',default= 125) + parser.add_argument('--workers', type=int,help='Number of workers',default=4) + parser.add_argument('--model', help='model name',default='densenet') + parser.add_argument('--evalsplit', help='eval spolit',default='val') + parser.add_argument('--lr', type=float,default= 1e-3,help='Learning rate') + parser.add_argument('--save', help='save folder name',default='0try') + parser.add_argument('--seed', type=int, default=1111, help='random seed') + parser.add_argument('--load', type=str, default=None, help='load checkpoint file name') + parser.add_argument('--resume', action='store_true', help='resume train from load chkpoint') + parser.add_argument('--test', action='store_true', help='test only') + parser.add_argument('--savemodel',action='store_true',help='checkpoint save the model') + parser.add_argument('--testrun', action='store_true', help='test run with few dataset') + parser.add_argument('--expname', type=str, default='info', help='extra explanation of the method') + parser.add_argument('--useGPU', type=str, default=True, help='Set it as False if GPU is unavailable') + + # parse + args = parser.parse_args() + opt = vars(args) + pprint('parsed input parameters:') + pprint(opt) + return args + +if __name__ == '__main__': + + opt = parse_args() + print('opt[\'dataset\'] is ', opt.dataset) + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..ab189de --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +torchsummary +tqdm +matplotlib +numpy +torch +PIL +cv2 +argparse +pprint diff --git a/starburst_black.png b/starburst_black.png new file mode 100644 index 0000000..7d7c0e4 Binary files /dev/null and b/starburst_black.png differ diff --git a/test.py b/test.py new file mode 100644 index 0000000..0a8ca96 --- /dev/null +++ b/test.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 2 11:37:59 2019 + +@author: aaa +""" +import torch +from dataset import IrisDataset +from torch.utils.data import DataLoader +import numpy as np +import matplotlib.pyplot as plt +from dataset import transform +import os +from opt import parse_args +from models import model_dict +from tqdm import tqdm +from utils import get_predictions +#%% + +if __name__ == '__main__': + + args = parse_args() + + if args.model not in model_dict: + print ("Model not found !!!") + print ("valid models are:",list(model_dict.keys())) + exit(1) + + if args.useGPU: + device=torch.device("cuda") + else: + device=torch.device("cpu") + + model = model_dict[args.model] + model = model.to(device) + filename = args.load + if not os.path.exists(filename): + print("model path not found !!!") + exit(1) + + model.load_state_dict(torch.load(filename)) + model = model.to(device) + model.eval() + + test_set = IrisDataset(filepath = 'Semantic_Segmentation_Dataset/',\ + split = 'test',transform = transform) + + testloader = DataLoader(test_set, batch_size = args.bs, + shuffle=False, num_workers=2) + counter=0 + + os.makedirs('test/labels/',exist_ok=True) + os.makedirs('test/output/',exist_ok=True) + os.makedirs('test/mask/',exist_ok=True) + + with torch.no_grad(): + for i, batchdata in tqdm(enumerate(testloader),total=len(testloader)): + img,labels,index,x,y= batchdata + data = img.to(device) + output = model(data) + predict = get_predictions(output) + for j in range (len(index)): + np.save('test/labels/{}.npy'.format(index[j]),predict[j].cpu().numpy()) + try: + plt.imsave('test/output/{}.jpg'.format(index[j]),255*labels[j].cpu().numpy()) + except: + pass + + pred_img = predict[j].cpu().numpy()/3.0 + inp = img[j].squeeze() * 0.5 + 0.5 + img_orig = np.clip(inp,0,1) + img_orig = np.array(img_orig) + combine = np.hstack([img_orig,pred_img]) + plt.imsave('test/mask/{}.jpg'.format(index[j]),combine) + + os.rename('test',args.save) diff --git a/test/labels/000000337021.npy b/test/labels/000000337021.npy new file mode 100644 index 0000000..4e23d05 Binary files /dev/null and b/test/labels/000000337021.npy differ diff --git a/test/labels/000000337043.npy b/test/labels/000000337043.npy new file mode 100644 index 0000000..2c9aa6e Binary files /dev/null and b/test/labels/000000337043.npy differ diff --git a/test/labels/000000337047.npy b/test/labels/000000337047.npy new file mode 100644 index 0000000..236d1e7 Binary files /dev/null and b/test/labels/000000337047.npy differ diff --git a/test/labels/000000337720.npy b/test/labels/000000337720.npy new file mode 100644 index 0000000..cdf97f1 Binary files /dev/null and b/test/labels/000000337720.npy differ diff --git a/test/labels/000000337733.npy b/test/labels/000000337733.npy new file mode 100644 index 0000000..1168634 Binary files /dev/null and b/test/labels/000000337733.npy differ diff --git a/test/labels/000000337735.npy b/test/labels/000000337735.npy new file mode 100644 index 0000000..d00bd95 Binary files /dev/null and b/test/labels/000000337735.npy differ diff --git a/test/labels/000000338387.npy b/test/labels/000000338387.npy new file mode 100644 index 0000000..c4ea0f3 Binary files /dev/null and b/test/labels/000000338387.npy differ diff --git a/test/labels/000000339828.npy b/test/labels/000000339828.npy new file mode 100644 index 0000000..a95d66a Binary files /dev/null and b/test/labels/000000339828.npy differ diff --git a/test/labels/000000339839.npy b/test/labels/000000339839.npy new file mode 100644 index 0000000..879e24f Binary files /dev/null and b/test/labels/000000339839.npy differ diff --git a/test/labels/000000339848.npy b/test/labels/000000339848.npy new file mode 100644 index 0000000..583cd33 Binary files /dev/null and b/test/labels/000000339848.npy differ diff --git a/test/labels/000000339850.npy b/test/labels/000000339850.npy new file mode 100644 index 0000000..988d4fe Binary files /dev/null and b/test/labels/000000339850.npy differ diff --git a/test/labels/000000340534.npy b/test/labels/000000340534.npy new file mode 100644 index 0000000..7ac4fd5 Binary files /dev/null and b/test/labels/000000340534.npy differ diff --git a/test/labels/000000340542.npy b/test/labels/000000340542.npy new file mode 100644 index 0000000..a472a8a Binary files /dev/null and b/test/labels/000000340542.npy differ diff --git a/test/labels/000000341973.npy b/test/labels/000000341973.npy new file mode 100644 index 0000000..89d9572 Binary files /dev/null and b/test/labels/000000341973.npy differ diff --git a/test/labels/000000342664.npy b/test/labels/000000342664.npy new file mode 100644 index 0000000..cd7cdf2 Binary files /dev/null and b/test/labels/000000342664.npy differ diff --git a/test/labels/000000343356.npy b/test/labels/000000343356.npy new file mode 100644 index 0000000..b9e4e68 Binary files /dev/null and b/test/labels/000000343356.npy differ diff --git a/test/labels/000000344729.npy b/test/labels/000000344729.npy new file mode 100644 index 0000000..4bfa41c Binary files /dev/null and b/test/labels/000000344729.npy differ diff --git a/test/labels/000000344746.npy b/test/labels/000000344746.npy new file mode 100644 index 0000000..ca9086d Binary files /dev/null and b/test/labels/000000344746.npy differ diff --git a/test/labels/000000345428.npy b/test/labels/000000345428.npy new file mode 100644 index 0000000..14c39d0 Binary files /dev/null and b/test/labels/000000345428.npy differ diff --git a/test/labels/000000346092.npy b/test/labels/000000346092.npy new file mode 100644 index 0000000..3f267a8 Binary files /dev/null and b/test/labels/000000346092.npy differ diff --git a/test/labels/000000346102.npy b/test/labels/000000346102.npy new file mode 100644 index 0000000..045cab5 Binary files /dev/null and b/test/labels/000000346102.npy differ diff --git a/test/labels/000000346107.npy b/test/labels/000000346107.npy new file mode 100644 index 0000000..3837fc8 Binary files /dev/null and b/test/labels/000000346107.npy differ diff --git a/test/labels/000000346784.npy b/test/labels/000000346784.npy new file mode 100644 index 0000000..f63cb8d Binary files /dev/null and b/test/labels/000000346784.npy differ diff --git a/test/labels/000000347467.npy b/test/labels/000000347467.npy new file mode 100644 index 0000000..3d8a92e Binary files /dev/null and b/test/labels/000000347467.npy differ diff --git a/test/labels/000000347488.npy b/test/labels/000000347488.npy new file mode 100644 index 0000000..86332ea Binary files /dev/null and b/test/labels/000000347488.npy differ diff --git a/test/labels/000000348174.npy b/test/labels/000000348174.npy new file mode 100644 index 0000000..77d414a Binary files /dev/null and b/test/labels/000000348174.npy differ diff --git a/test/labels/000000348182.npy b/test/labels/000000348182.npy new file mode 100644 index 0000000..97d426e Binary files /dev/null and b/test/labels/000000348182.npy differ diff --git a/test/labels/000000348212.npy b/test/labels/000000348212.npy new file mode 100644 index 0000000..1107a73 Binary files /dev/null and b/test/labels/000000348212.npy differ diff --git a/test/labels/000000348897.npy b/test/labels/000000348897.npy new file mode 100644 index 0000000..247334b Binary files /dev/null and b/test/labels/000000348897.npy differ diff --git a/test/labels/000000349618.npy b/test/labels/000000349618.npy new file mode 100644 index 0000000..f07824a Binary files /dev/null and b/test/labels/000000349618.npy differ diff --git a/test/labels/000000350263.npy b/test/labels/000000350263.npy new file mode 100644 index 0000000..530273e Binary files /dev/null and b/test/labels/000000350263.npy differ diff --git a/test/labels/000000350293.npy b/test/labels/000000350293.npy new file mode 100644 index 0000000..8476d6a Binary files /dev/null and b/test/labels/000000350293.npy differ diff --git a/test/labels/000000352383.npy b/test/labels/000000352383.npy new file mode 100644 index 0000000..d7084b7 Binary files /dev/null and b/test/labels/000000352383.npy differ diff --git a/test/labels/000000352385.npy b/test/labels/000000352385.npy new file mode 100644 index 0000000..3216b16 Binary files /dev/null and b/test/labels/000000352385.npy differ diff --git a/test/labels/000000352398.npy b/test/labels/000000352398.npy new file mode 100644 index 0000000..df9983d Binary files /dev/null and b/test/labels/000000352398.npy differ diff --git a/test/labels/000000352410.npy b/test/labels/000000352410.npy new file mode 100644 index 0000000..f820d9a Binary files /dev/null and b/test/labels/000000352410.npy differ diff --git a/test/labels/000000353117.npy b/test/labels/000000353117.npy new file mode 100644 index 0000000..09cca81 Binary files /dev/null and b/test/labels/000000353117.npy differ diff --git a/test/labels/000000353820.npy b/test/labels/000000353820.npy new file mode 100644 index 0000000..0ba8d2b Binary files /dev/null and b/test/labels/000000353820.npy differ diff --git a/test/labels/000000355220.npy b/test/labels/000000355220.npy new file mode 100644 index 0000000..21017c0 Binary files /dev/null and b/test/labels/000000355220.npy differ diff --git a/test/labels/000000355243.npy b/test/labels/000000355243.npy new file mode 100644 index 0000000..9d6df4b Binary files /dev/null and b/test/labels/000000355243.npy differ diff --git a/test/mask/000000337021.jpg b/test/mask/000000337021.jpg new file mode 100644 index 0000000..cd903f0 Binary files /dev/null and b/test/mask/000000337021.jpg differ diff --git a/test/mask/000000337043.jpg b/test/mask/000000337043.jpg new file mode 100644 index 0000000..51f000b Binary files /dev/null and b/test/mask/000000337043.jpg differ diff --git a/test/mask/000000337047.jpg b/test/mask/000000337047.jpg new file mode 100644 index 0000000..17ca205 Binary files /dev/null and b/test/mask/000000337047.jpg differ diff --git a/test/mask/000000337720.jpg b/test/mask/000000337720.jpg new file mode 100644 index 0000000..2de7c0a Binary files /dev/null and b/test/mask/000000337720.jpg differ diff --git a/test/mask/000000337733.jpg b/test/mask/000000337733.jpg new file mode 100644 index 0000000..ae4fdbd Binary files /dev/null and b/test/mask/000000337733.jpg differ diff --git a/test/mask/000000337735.jpg b/test/mask/000000337735.jpg new file mode 100644 index 0000000..30f92cc Binary files /dev/null and b/test/mask/000000337735.jpg differ diff --git a/test/mask/000000338387.jpg b/test/mask/000000338387.jpg new file mode 100644 index 0000000..cae67ad Binary files /dev/null and b/test/mask/000000338387.jpg differ diff --git a/test/mask/000000339828.jpg b/test/mask/000000339828.jpg new file mode 100644 index 0000000..c5270cf Binary files /dev/null and b/test/mask/000000339828.jpg differ diff --git a/test/mask/000000339839.jpg b/test/mask/000000339839.jpg new file mode 100644 index 0000000..bb6bb95 Binary files /dev/null and b/test/mask/000000339839.jpg differ diff --git a/test/mask/000000339848.jpg b/test/mask/000000339848.jpg new file mode 100644 index 0000000..5592886 Binary files /dev/null and b/test/mask/000000339848.jpg differ diff --git a/test/mask/000000339850.jpg b/test/mask/000000339850.jpg new file mode 100644 index 0000000..82bf7c4 Binary files /dev/null and b/test/mask/000000339850.jpg differ diff --git a/test/mask/000000340534.jpg b/test/mask/000000340534.jpg new file mode 100644 index 0000000..524c7d7 Binary files /dev/null and b/test/mask/000000340534.jpg differ diff --git a/test/mask/000000340542.jpg b/test/mask/000000340542.jpg new file mode 100644 index 0000000..c79d758 Binary files /dev/null and b/test/mask/000000340542.jpg differ diff --git a/test/mask/000000341973.jpg b/test/mask/000000341973.jpg new file mode 100644 index 0000000..6c87790 Binary files /dev/null and b/test/mask/000000341973.jpg differ diff --git a/test/mask/000000342664.jpg b/test/mask/000000342664.jpg new file mode 100644 index 0000000..4f047e1 Binary files /dev/null and b/test/mask/000000342664.jpg differ diff --git a/test/mask/000000343356.jpg b/test/mask/000000343356.jpg new file mode 100644 index 0000000..85e532b Binary files /dev/null and b/test/mask/000000343356.jpg differ diff --git a/test/mask/000000344729.jpg b/test/mask/000000344729.jpg new file mode 100644 index 0000000..c46d332 Binary files /dev/null and b/test/mask/000000344729.jpg differ diff --git a/test/mask/000000344746.jpg b/test/mask/000000344746.jpg new file mode 100644 index 0000000..4b38122 Binary files /dev/null and b/test/mask/000000344746.jpg differ diff --git a/test/mask/000000345428.jpg b/test/mask/000000345428.jpg new file mode 100644 index 0000000..2062b72 Binary files /dev/null and b/test/mask/000000345428.jpg differ diff --git a/test/mask/000000346092.jpg b/test/mask/000000346092.jpg new file mode 100644 index 0000000..c50b65a Binary files /dev/null and b/test/mask/000000346092.jpg differ diff --git a/test/mask/000000346102.jpg b/test/mask/000000346102.jpg new file mode 100644 index 0000000..514e172 Binary files /dev/null and b/test/mask/000000346102.jpg differ diff --git a/test/mask/000000346107.jpg b/test/mask/000000346107.jpg new file mode 100644 index 0000000..6b3466f Binary files /dev/null and b/test/mask/000000346107.jpg differ diff --git a/test/mask/000000346784.jpg b/test/mask/000000346784.jpg new file mode 100644 index 0000000..b18fd22 Binary files /dev/null and b/test/mask/000000346784.jpg differ diff --git a/test/mask/000000347467.jpg b/test/mask/000000347467.jpg new file mode 100644 index 0000000..d864226 Binary files /dev/null and b/test/mask/000000347467.jpg differ diff --git a/test/mask/000000347488.jpg b/test/mask/000000347488.jpg new file mode 100644 index 0000000..8f32d24 Binary files /dev/null and b/test/mask/000000347488.jpg differ diff --git a/test/mask/000000348174.jpg b/test/mask/000000348174.jpg new file mode 100644 index 0000000..2b6c39d Binary files /dev/null and b/test/mask/000000348174.jpg differ diff --git a/test/mask/000000348182.jpg b/test/mask/000000348182.jpg new file mode 100644 index 0000000..e6881f5 Binary files /dev/null and b/test/mask/000000348182.jpg differ diff --git a/test/mask/000000348212.jpg b/test/mask/000000348212.jpg new file mode 100644 index 0000000..aa9ff42 Binary files /dev/null and b/test/mask/000000348212.jpg differ diff --git a/test/mask/000000348897.jpg b/test/mask/000000348897.jpg new file mode 100644 index 0000000..ebf3b5f Binary files /dev/null and b/test/mask/000000348897.jpg differ diff --git a/test/mask/000000349618.jpg b/test/mask/000000349618.jpg new file mode 100644 index 0000000..bcfc9a3 Binary files /dev/null and b/test/mask/000000349618.jpg differ diff --git a/test/mask/000000350263.jpg b/test/mask/000000350263.jpg new file mode 100644 index 0000000..66214e1 Binary files /dev/null and b/test/mask/000000350263.jpg differ diff --git a/test/mask/000000350293.jpg b/test/mask/000000350293.jpg new file mode 100644 index 0000000..0a2e8ac Binary files /dev/null and b/test/mask/000000350293.jpg differ diff --git a/test/mask/000000352383.jpg b/test/mask/000000352383.jpg new file mode 100644 index 0000000..034cfae Binary files /dev/null and b/test/mask/000000352383.jpg differ diff --git a/test/mask/000000352385.jpg b/test/mask/000000352385.jpg new file mode 100644 index 0000000..e5e4fb4 Binary files /dev/null and b/test/mask/000000352385.jpg differ diff --git a/test/mask/000000352398.jpg b/test/mask/000000352398.jpg new file mode 100644 index 0000000..58f5852 Binary files /dev/null and b/test/mask/000000352398.jpg differ diff --git a/test/mask/000000352410.jpg b/test/mask/000000352410.jpg new file mode 100644 index 0000000..dba18b6 Binary files /dev/null and b/test/mask/000000352410.jpg differ diff --git a/test/mask/000000353117.jpg b/test/mask/000000353117.jpg new file mode 100644 index 0000000..0c4842b Binary files /dev/null and b/test/mask/000000353117.jpg differ diff --git a/test/mask/000000353820.jpg b/test/mask/000000353820.jpg new file mode 100644 index 0000000..b9abfcb Binary files /dev/null and b/test/mask/000000353820.jpg differ diff --git a/test/mask/000000355220.jpg b/test/mask/000000355220.jpg new file mode 100644 index 0000000..3913699 Binary files /dev/null and b/test/mask/000000355220.jpg differ diff --git a/test/mask/000000355243.jpg b/test/mask/000000355243.jpg new file mode 100644 index 0000000..223cd14 Binary files /dev/null and b/test/mask/000000355243.jpg differ diff --git a/train.py b/train.py new file mode 100644 index 0000000..6bd5be3 --- /dev/null +++ b/train.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 2 11:22:32 2019 + +@author: aayush +""" + +from models import model_dict +from torch.utils.data import DataLoader +from dataset import IrisDataset +import torch +from utils import mIoU, CrossEntropyLoss2d,total_metric,get_nparams,Logger,GeneralizedDiceLoss,SurfaceLoss +import numpy as np +from dataset import transform +from opt import parse_args +import os +from utils import get_predictions +from tqdm import tqdm +import matplotlib.pyplot as plt +#%% + +def lossandaccuracy(loader,model,factor): + epoch_loss = [] + ious = [] + model.eval() + with torch.no_grad(): + for i, batchdata in enumerate(loader): +# print (len(batchdata)) + img,labels,index,spatialWeights,maxDist=batchdata + data = img.to(device) + + target = labels.to(device).long() + output = model(data) + + ## loss from cross entropy is weighted sum of pixel wise loss and Canny edge loss *20 + CE_loss = criterion(output,target) + loss = CE_loss*(torch.from_numpy(np.ones(spatialWeights.shape)).to(torch.float32).to(device)()+(spatialWeights).to(torch.float32).to(device)()) + + loss=torch.mean(loss).to(torch.float32).to(device) + loss_dice = criterion_DICE(output,target) + loss_sl = torch.mean(criterion_SL(output.to(device),(maxDist).to(device))) + + ##total loss is the weighted sum of suface loss and dice loss plus the boundary weighted cross entropy loss + loss = (1-factor)*loss_sl+factor*(loss_dice)+loss + + epoch_loss.append(loss.item()) + predict = get_predictions(output) + iou = mIoU(predict,labels) + ious.append(iou) + return np.average(epoch_loss),np.average(ious) + +#%% +if __name__ == '__main__': + + args = parse_args() + kwargs = vars(args) + +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.useGPU: + device=torch.device("cuda") + torch.cuda.manual_seed(12) + else: + device=torch.device("cpu") + torch.manual_seed(12) + + torch.backends.cudnn.deterministic=False + + if args.model not in model_dict: + print ("Model not found !!!") + print ("valid models are:",list(model_dict.keys())) + exit(1) + + LOGDIR = 'logs/{}'.format(args.expname) + os.makedirs(LOGDIR,exist_ok=True) + os.makedirs(LOGDIR+'/models',exist_ok=True) + logger = Logger(os.path.join(LOGDIR,'logs.log')) + + model = model_dict[args.model] + model = model.to(device) + torch.save(model.state_dict(), '{}/models/dense_net{}.pkl'.format(LOGDIR,'_0')) + model.train() + nparams = get_nparams(model) + + try: + from torchsummary import summary + summary(model,input_size=(1,640,400)) + print("Max params:", 1024*1024/4.0) + logger.write_summary(str(model.parameters)) + except: + print ("Torch summary not found !!!") + + optimizer = torch.optim.Adam(model.parameters(), lr = args.lr) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=5) + + criterion = CrossEntropyLoss2d() + criterion_DICE = GeneralizedDiceLoss(softmax=True, reduction=True) + criterion_SL = SurfaceLoss() + + Path2file = args.dataset + train = IrisDataset(filepath = Path2file,split='train', + transform = transform, **kwargs) + + valid = IrisDataset(filepath = Path2file , split='validation', + transform = transform, **kwargs) + + trainloader = DataLoader(train, batch_size = args.bs, + shuffle=True, num_workers = args.workers) + + validloader = DataLoader(valid, batch_size = args.bs, + shuffle= False, num_workers = args.workers) + + test = IrisDataset(filepath = Path2file , split='test', + transform = transform, **kwargs) + + testloader = DataLoader(test, batch_size = args.bs, + shuffle=False, num_workers = args.workers) + + +# alpha = 1 - np.arange(1,args.epochs)/args.epoch + ##The weighing function for the dice loss and surface loss + alpha=np.zeros(((args.epochs))) + alpha[0:np.min([125,args.epochs])]=1 - np.arange(1,np.min([125,args.epochs])+1)/np.min([125,args.epochs]) + if args.epochs>125: + alpha[125:]=1 + ious = [] + for epoch in range(args.epochs): + for i, batchdata in enumerate(trainloader): +# print (len(batchdata)) + img,labels,index,spatialWeights,maxDist= batchdata + data = img.to(device) + target = labels.to(device).long() + optimizer.zero_grad() + output = model(data) + ## loss from cross entropy is weighted sum of pixel wise loss and Canny edge loss *20 + CE_loss = criterion(output,target) + loss = CE_loss*(torch.from_numpy(np.ones(spatialWeights.shape)).to(torch.float32).to(device)+(spatialWeights).to(torch.float32).to(device)) + + loss=torch.mean(loss).to(torch.float32).to(device) + loss_dice = criterion_DICE(output,target) + loss_sl = torch.mean(criterion_SL(output.to(device),(maxDist).to(device))) + + ##total loss is the weighted sum of suface loss and dice loss plus the boundary weighted cross entropy loss + loss = (1-alpha[epoch])*loss_sl+alpha[epoch]*(loss_dice)+loss +# + predict = get_predictions(output) + iou = mIoU(predict,labels) + ious.append(iou) + + if i%10 == 0: + logger.write('Epoch:{} [{}/{}], Loss: {:.3f}'.format(epoch,i,len(trainloader),loss.item())) + + loss.backward() + optimizer.step() + + logger.write('Epoch:{}, Train mIoU: {}'.format(epoch,np.average(ious))) + lossvalid , miou = lossandaccuracy(validloader,model,alpha[epoch]) + totalperf = total_metric(nparams,miou) + f = 'Epoch:{}, Valid Loss: {:.3f} mIoU: {} Complexity: {} total: {}' + logger.write(f.format(epoch,lossvalid, miou,nparams,totalperf)) + + scheduler.step(lossvalid) + + ##save the model every epoch + if epoch %1 == 0: + torch.save(model.state_dict(), '{}/models/dense_net{}.pkl'.format(LOGDIR,epoch)) + + ##visualize the ouput every 5 epoch + if epoch %5 ==0: + os.makedirs('test/epoch/labels/',exist_ok=True) + os.makedirs('test/epoch/output/',exist_ok=True) + os.makedirs('test/epoch/mask/',exist_ok=True) + + with torch.no_grad(): + for i, batchdata in tqdm(enumerate(testloader),total=len(testloader)): + img,labels,index,x,maxDist= batchdata + data = img.to(device) + output = model(data) + predict = get_predictions(output) + for j in range (len(index)): + np.save('test/epoch/labels/{}.npy'.format(index[j]),predict[j].cpu().numpy()) + try: + plt.imsave('test/epoch/output/{}.jpg'.format(index[j]),255*labels[j].cpu().numpy()) + except: + pass + pred_img = predict[j].cpu().numpy()/3.0 + inp = img[j].squeeze() * 0.5 + 0.5 + img_orig = np.clip(inp,0,1) + img_orig = np.array(img_orig) + combine = np.hstack([img_orig,pred_img]) + plt.imsave('test/epoch/mask/{}.jpg'.format(index[j]),combine) + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..3c4d6b7 --- /dev/null +++ b/utils.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 27 16:04:18 2019 + +@author: Aayush Chaudhary +""" +##https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import cv2 +import os + +from sklearn.metrics import precision_score , recall_score,f1_score +from scipy.ndimage import distance_transform_edt as distance +#%% +class FocalLoss2d(nn.Module): + def __init__(self, weight=None,gamma=2): + super(FocalLoss2d,self).__init__() + self.gamma = gamma + self.loss = nn.NLLLoss(weight) + def forward(self, outputs, targets): + return self.loss((1 - nn.Softmax2d()(outputs)).pow(self.gamma) * torch.log(nn.Softmax2d()(outputs)), targets) + +###https://github.com/ycszen/pytorch-segmentation/blob/master/loss.py +# https://discuss.pytorch.org/t/using-cross-entropy-loss-with-semantic-segmentation-model/31988 +class CrossEntropyLoss2d(nn.Module): + + def __init__(self, weight=None): + super(CrossEntropyLoss2d,self).__init__() + self.loss = nn.NLLLoss(weight) + + def forward(self, outputs, targets): + return self.loss(F.log_softmax(outputs,dim=1), targets) + +class SurfaceLoss(nn.Module): + # Author: Rakshit Kothari + def __init__(self, epsilon=1e-5, softmax=True): + super(SurfaceLoss, self).__init__() + self.weight_map = [] + def forward(self, x, distmap): + x = torch.softmax(x, dim=1) + self.weight_map = distmap + score = x.flatten(start_dim=2)*distmap.flatten(start_dim=2) + score = torch.mean(score, dim=2) # Mean between pixels per channel + score = torch.mean(score, dim=1) # Mean between channels + return score + + +class GeneralizedDiceLoss(nn.Module): + # Author: Rakshit Kothari + # Input: (B, C, ...) + # Target: (B, C, ...) + def __init__(self, epsilon=1e-5, weight=None, softmax=True, reduction=True): + super(GeneralizedDiceLoss, self).__init__() + self.epsilon = epsilon + self.weight = [] + self.reduction = reduction + if softmax: + self.norm = nn.Softmax(dim=1) + else: + self.norm = nn.Sigmoid() + + def forward(self, ip, target): + + # Rapid way to convert to one-hot. For future version, use functional + Label = (np.arange(4) == target.cpu().numpy()[..., None]).astype(np.uint8) + target = torch.from_numpy(np.rollaxis(Label, 3,start=1)).cuda() + + assert ip.shape == target.shape + ip = self.norm(ip) + + # Flatten for multidimensional data + ip = torch.flatten(ip, start_dim=2, end_dim=-1).cuda().to(torch.float32) + target = torch.flatten(target, start_dim=2, end_dim=-1).cuda().to(torch.float32) + + numerator = ip*target + denominator = ip + target + + class_weights = 1./(torch.sum(target, dim=2)**2).clamp(min=self.epsilon) + + A = class_weights*torch.sum(numerator, dim=2) + B = class_weights*torch.sum(denominator, dim=2) + + dice_metric = 2.*torch.sum(A, dim=1)/torch.sum(B, dim=1) + if self.reduction: + return torch.mean(1. - dice_metric.clamp(min=self.epsilon)) + else: + return 1. - dice_metric.clamp(min=self.epsilon) + +def one_hot2dist(posmask): + # Input: Mask. Will be converted to Bool. + # Author: Rakshit Kothari + assert len(posmask.shape) == 2 + h, w = posmask.shape + res = np.zeros_like(posmask) + posmask = posmask.astype(np.bool) + mxDist = np.sqrt((h-1)**2 + (w-1)**2) + if posmask.any(): + negmask = ~posmask + res = distance(negmask) * negmask - (distance(posmask) - 1) * posmask + return res/mxDist + +def mIoU(predictions, targets,info=False): ###Mean per class accuracy + unique_labels = np.unique(targets) + num_unique_labels = len(unique_labels) + ious = [] + for index in range(num_unique_labels): + pred_i = predictions == index + label_i = targets == index + intersection = np.logical_and(label_i, pred_i) + union = np.logical_or(label_i, pred_i) + iou_score = np.sum(intersection.numpy())/np.sum(union.numpy()) + ious.append(iou_score) + if info: + print ("per-class mIOU: ", ious) + return np.mean(ious) + +#GA: Global Pixel Accuracy +#CA: Mean Class Accuracy for different classes +# +#Back: Background (non-eye part of peri-ocular region) +#Sclera: Sclera +#Iris: Iris +#Pupil: Pupil +#Precision: Computed using sklearn.metrics.precision_score(pred, gt, ‘weighted’) +#Recall: Computed using sklearn.metrics.recall_score(pred, gt, ‘weighted’) +#F1: Computed using sklearn.metrics.f1_score(pred, gt, ‘weighted’) +#IoU: Computed using the function below +def compute_mean_iou(flat_pred, flat_label,info=False): + ''' + compute mean intersection over union (IOU) over all classes + :param flat_pred: flattened prediction matrix + :param flat_label: flattened label matrix + :return: mean IOU + ''' + unique_labels = np.unique(flat_label) + num_unique_labels = len(unique_labels) + + Intersect = np.zeros(num_unique_labels) + Union = np.zeros(num_unique_labels) + precision = np.zeros(num_unique_labels) + recall = np.zeros(num_unique_labels) + f1 = np.zeros(num_unique_labels) + + for index, val in enumerate(unique_labels): + pred_i = flat_pred == val + label_i = flat_label == val + + if info: + precision[index] = precision_score(pred_i, label_i, 'weighted') + recall[index] = recall_score(pred_i, label_i, 'weighted') + f1[index] = f1_score(pred_i, label_i, 'weighted') + + Intersect[index] = float(np.sum(np.logical_and(label_i, pred_i))) + Union[index] = float(np.sum(np.logical_or(label_i, pred_i))) + + if info: + print ("per-class mIOU: ", Intersect / Union) + print ("per-class precision: ", precision) + print ("per-class recall: ", recall) + print ("per-class f1: ", f1) + mean_iou = np.mean(Intersect / Union) + return mean_iou + +def total_metric(nparams,miou): + S = nparams * 4.0 / (1024 * 1024) + total = min(1,1.0/S) + miou + return total * 0.5 + + +def get_nparams(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def get_predictions(output): + bs,c,h,w = output.size() + values, indices = output.cpu().max(1) + indices = indices.view(bs,h,w) # bs x h x w + return indices + + +class Logger(): + def __init__(self, output_name): + dirname = os.path.dirname(output_name) + if not os.path.exists(dirname): + os.mkdir(dirname) + self.dirname = dirname + self.log_file = open(output_name, 'a+') + self.infos = {} + + def append(self, key, val): + vals = self.infos.setdefault(key, []) + vals.append(val) + + def log(self, extra_msg=''): + msgs = [extra_msg] + for key, vals in self.infos.iteritems(): + msgs.append('%s %.6f' % (key, np.mean(vals))) + msg = '\n'.join(msgs) + self.log_file.write(msg + '\n') + self.log_file.flush() + self.infos = {} + return msg + + def write_silent(self, msg): + self.log_file.write(msg + '\n') + self.log_file.flush() + + def write(self, msg): + self.log_file.write(msg + '\n') + self.log_file.flush() + print (msg) + def write_summary(self,msg): + self.log_file.write(msg) + self.log_file.write('\n') + self.log_file.flush() + print (msg) +