# teste.py
import argparse

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

use_cuda = torch.cuda.is_available()
parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--scribble', action='store_true', default=False, help='use scribbles')
parser.add_argument('--nChannel', metavar='N', default=100, type=int, help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=500, type=int, help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int, help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float, help='learning rate')
parser.add_argument('--nConv', metavar='M', default=3, type=int, help='number of convolutional layers')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int, help='visualization flag')
# parser.add_argument('--input', metavar='FILENAME', default=r'D:\Users\paulo\PycharmProjects\pytorch-unsupervised-segmentation-tip\imagens\3.png', help='input image file name', required=False)
parser.add_argument('--input', metavar='FILENAME', default=r'D:\Users\paulo\PycharmProjects\pytorch-unsupervised-segmentation-tip\imagens\Normal-CT-head-3Age-30-40_pt.jpg', help='input image file name', required=False)
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float, help='step size for similarity loss', required=False)
parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float, help='step size for continuity loss')
parser.add_argument('--stepsize_scr', metavar='SCR', default=0.5, type=float, help='step size for scribble loss')
args = parser.parse_args()
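# Example invocation (illustrative; the --input default above is machine-specific):
#   python teste.py --input imagens/3.png --nChannel 100 --nConv 3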
# CNN model: (3x3 conv -> ReLU -> BatchNorm) repeated nConv times,
# then a 1x1 projection to nChannel response maps
class MyNet(nn.Module):
    def __init__(self, input_dim):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(args.nChannel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for i in range(args.nConv - 1):
            self.conv2.append(nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(args.nChannel))
        self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(args.nChannel)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        for i in range(args.nConv - 1):
            x = self.conv2[i](x)
            x = F.relu(x)
            x = self.bn2[i](x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x
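# Shape check (illustrative sketch, not part of the original script): the
# 3x3 / stride-1 / pad-1 convolutions preserve spatial size, so
#   net = MyNet(3)
#   net(torch.zeros(1, 3, 32, 32)).shape  # -> (1, args.nChannel, 32, 32)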
# load image
im = cv2.imread(args.input)
if im is None:
    raise FileNotFoundError('could not read input image: ' + args.input)
# HWC uint8 -> 1xCxHxW float32 in [0, 1]
data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32') / 255.]))
if use_cuda:
    data = data.cuda()
# load scribble (optional): pixels valued 255 are unlabeled, any other
# value is a user-provided class id
if args.scribble:
    mask = cv2.imread(args.input.replace('.' + args.input.split('.')[-1], '_scribble.png'), -1)
    mask = mask.reshape(-1)
    mask_inds = np.unique(mask)
    mask_inds = np.delete(mask_inds, np.argwhere(mask_inds == 255))
    inds_sim = torch.from_numpy(np.where(mask == 255)[0])
    inds_scr = torch.from_numpy(np.where(mask != 255)[0])
    target_scr = torch.from_numpy(mask.astype(np.int64))
    if use_cuda:
        inds_sim = inds_sim.cuda()
        inds_scr = inds_scr.cuda()
        target_scr = target_scr.cuda()
    # use the number of distinct scribble classes as minLabels
    args.minLabels = len(mask_inds)
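# In the training script this file appears trimmed from, these indices choose
# which pixels get the similarity loss (unlabeled, 255) versus the scribble
# loss (user-labeled). A sketch of that combination, using the stepsize_*
# arguments defined above (not executed here; loss_fn/loss_fn_scr/target are
# names from the training loop, not defined in this script):
#   loss = args.stepsize_sim * loss_fn(output[inds_sim], target[inds_sim]) \
#        + args.stepsize_scr * loss_fn_scr(output[inds_scr], target_scr[inds_scr]) \
#        + args.stepsize_con * (lhpy + lhpz)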
# build the model and load the trained weights (this script runs inference only)
model = MyNet(data.size(1))
if use_cuda:
    model.cuda()
# continuity loss definition
loss_hpy = torch.nn.L1Loss(reduction='mean')
loss_hpz = torch.nn.L1Loss(reduction='mean')
# targets of zeros: adjacent pixels should have identical responses
HPy_target = torch.zeros(im.shape[0] - 1, im.shape[1], args.nChannel)
HPz_target = torch.zeros(im.shape[0], im.shape[1] - 1, args.nChannel)
if use_cuda:
    HPy_target = HPy_target.cuda()
    HPz_target = HPz_target.cuda()
# one random colour per possible label
label_colours = np.random.randint(255, size=(args.nChannel, 3))
model.load_state_dict(torch.load(
    r'D:\Users\paulo\PycharmProjects\pytorch-unsupervised-segmentation-tip\results\model.pth',
    map_location='cuda' if use_cuda else 'cpu'))
model.eval()
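# Note: load_state_dict is strict by default, so --nChannel and --nConv must
# match the values used when results\model.pth was trained, or loading fails.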
# forward pass: one nChannel-dimensional response vector per pixel
with torch.no_grad():
    output = model(data)[0]
output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)
outputHP = output.reshape((im.shape[0], im.shape[1], args.nChannel))
# differences between vertically (HPy) and horizontally (HPz) adjacent responses
HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]
HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]
lhpy = loss_hpy(HPy, HPy_target)
lhpz = loss_hpz(HPz, HPz_target)
# the per-pixel argmax over channels is the predicted cluster label
_, target = torch.max(output, 1)
im_target = target.cpu().numpy()
nLabels = len(np.unique(im_target))
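# The training loop this script seems derived from presumably stops refining
# once nLabels reaches args.minLabels; at inference we can only report the
# count (illustrative addition):
print('inference produced', nLabels, 'distinct labels (minLabels =', args.minLabels, ')')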
if args.visualize:
    # paint each label with its random colour
    im_target_rgb = np.array([label_colours[c % args.nChannel] for c in im_target])
    im_target_rgb = im_target_rgb.reshape(im.shape).astype(np.uint8)
    im_target_rgb = cv2.resize(im_target_rgb, (600, 600))
    cv2.imshow("output", im_target_rgb)
    cv2.imwrite("output.png", im_target_rgb)
    cv2.waitKey(0)  # block until a key press so the window stays visible