# Test.py (forked from zonechen1994/AI_TLS_segmentation)
import os

# Select the GPU before anything touches CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

from lib.PSCANet_ab import PSCANet
from config import getConfig
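
# Inference script for PSCANet: restores the best checkpoint for the current
# configuration, runs it over the image list of the chosen stage, and writes
# the predicted probability maps (sigmoid, scaled to 0-255) under
# ./result_maps/<model>/<stage>/.
# Hypothetical invocation, assuming getConfig() exposes these options as CLI
# flags: python Test.py --model PSCANet --stage test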
if __name__ == '__main__':
    opt = getConfig()

    # Build the model and restore the best checkpoint for this configuration.
    model = PSCANet(opt)
    model_path = os.path.join(opt.train_save, opt.model,
                              '{}_best_model.pth'.format(opt.refine_channels))
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()
    ##### put data_path here #####
    data_path = '../../datasets/TLS_Segmentation/TLS_data/'
    stage = opt.stage

    # A single os.makedirs call with exist_ok=True creates the whole
    # ./result_maps/<model>/<stage>/ hierarchy.
    save_path = './result_maps/{}/{}/'.format(opt.model, stage)
    os.makedirs(save_path, exist_ok=True)
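
    # The list file for this stage (e.g. test.txt) holds one sample per line in
    # the form "relative/image/path,relative/mask/path"; the mask path is
    # parsed but not used at inference time.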
    file_path = os.path.join(data_path, stage + '.txt')
    imgs = []
    gts = []
    with open(file_path, 'r') as f:
        for line in f:
            contents = line.strip().split(',')
            imgs.append(contents[0])
            gts.append(contents[1])
    # Preprocessing: resize to the training crop size, convert to a tensor and
    # normalize with ImageNet mean/std.
    CropSize = opt.trainsize
    img_transforms = transforms.Compose([
        transforms.Resize((CropSize, CropSize)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
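
    # Inference loop: for each listed image, run the network, resize the main
    # output P1 back to the original resolution, and save the sigmoid map as an
    # 8-bit image under a sub-directory mirroring the input layout.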
    with torch.no_grad():
        for img_path in tqdm(imgs):
            # Paths in the list file look like "<sub_dir>/<file_name>".
            dirs, name = img_path.split('/')
            os.makedirs(os.path.join(save_path, dirs), exist_ok=True)
            patch_path = os.path.join(save_path, dirs, name)

            img = Image.open(os.path.join(data_path, img_path)).convert('RGB')
            w, h = img.size  # PIL reports (width, height)
            inputs = img_transforms(img).unsqueeze(0).cuda()

            P1, P2, P3, P4, P5 = model(inputs)
            # F.interpolate expects size=(H, W).
            res = F.interpolate(P1, size=(h, w), mode='bilinear', align_corners=False)
            res = res.sigmoid().cpu().numpy().squeeze()
            im = Image.fromarray((res * 255).astype(np.uint8)).convert('RGB')
            im.save(patch_path)

    print('Finish!!!!!')