train.py
"""Training entry point: runs simplecv's dp_train with a custom evaluation
function that logs OA, AA, Cohen's kappa, per-class accuracy, and the
confusion matrix on the test set."""
import time

import torch

from simplecv import dp_train as train
from simplecv.util import metric
from simplecv.util.logger import eval_progress, speed

# Imported for their registration side effects: the model and dataset
# classes are looked up by name from the config.
from module import CLSJE  # noqa: F401
import data.dataloader  # noqa: F401
def fcn_evaluate_fn(self, test_dataloader, config):
    """Evaluate the current model on the test set and log the metrics."""
    eval_start = time.perf_counter()
    if self.checkpoint.global_step < 0:
        return
    self._model.eval()
    total_time = 0.
    with torch.no_grad():
        for idx, (im, mask, w) in enumerate(test_dataloader):
            start = time.time()
            y_pred = self._model(im).squeeze()
            torch.cuda.synchronize()
            time_cost = round(time.time() - start, 3)
            # Ground-truth class indices are 1-based.
            y_pred = y_pred.argmax(dim=0).cpu() + 1
            # Keep only the labeled pixels selected by the mask w.
            w = w.view(-1).bool()
            mask = torch.masked_select(mask.view(-1), w)
            y_pred = torch.masked_select(y_pred.view(-1), w)
            oa = metric.th_overall_accuracy_score(mask, y_pred)
            confusion_matrix = metric.th_confusion_matrix1(
                mask, y_pred, self._model.module.config.num_classes)
            aa, acc_per_class = metric.th_average_accuracy_score(
                mask, y_pred, self._model.module.config.num_classes,
                return_accuracys=True)
            kappa = metric.th_cohen_kappa_score(
                mask, y_pred, self._model.module.config.num_classes)
            total_time += time_cost
            speed(self._logger, time_cost, 'im')
            eval_progress(self._logger, idx + 1, len(test_dataloader))
    speed(self._logger, round(total_time / len(test_dataloader), 3), 'batched im (avg)')
    # The metrics from the last evaluated batch are the ones that get logged.
    metric_dict = {
        'OA': oa.item(),
        'AA': aa.item(),
        'Kappa': kappa.item(),
        'Confusion_matrix': confusion_matrix
    }
    for i, acc in enumerate(acc_per_class):
        metric_dict['acc_{}'.format(i + 1)] = acc.item()
    self._logger.eval_log(metric_dict=metric_dict, step=self.checkpoint.global_step)
    print('test time: %s seconds' % (time.perf_counter() - eval_start))
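

# --- Reference sketch (not part of the training pipeline) -------------------
# A minimal plain-torch illustration of what the metric calls above compute.
# These helper names are hypothetical; the real implementations live in
# simplecv.util.metric. Labels are assumed 1-based, as in fcn_evaluate_fn.
def _overall_accuracy(y_true, y_pred):
    # OA: fraction of correctly predicted pixels.
    return (y_true == y_pred).float().mean()

def _average_accuracy(y_true, y_pred, num_classes):
    # AA: mean of per-class recalls over the classes present in y_true.
    accs = [(y_pred[y_true == c] == c).float().mean()
            for c in range(1, num_classes + 1) if (y_true == c).any()]
    return torch.stack(accs).mean()

def _cohen_kappa(y_true, y_pred, num_classes):
    # kappa = (p_o - p_e) / (1 - p_e), computed from the confusion matrix.
    cm = torch.zeros(num_classes, num_classes)
    for t, p in zip(y_true.tolist(), y_pred.tolist()):
        cm[t - 1, p - 1] += 1
    n = cm.sum()
    p_o = cm.diag().sum() / n
    p_e = (cm.sum(dim=0) * cm.sum(dim=1)).sum() / (n * n)
    return (p_o - p_e) / (1 - p_e)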
def register_evaluate_fn(launcher):
    launcher.override_evaluate(fcn_evaluate_fn)
if __name__ == '__main__':
    torch.backends.cudnn.benchmark = True
    args = train.parser.parse_args()
    # Fix the random seed for reproducibility.
    SEED = 2333
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    train_start = time.perf_counter()
    train.run(config_path=args.config_path,
              model_dir=args.model_dir,
              cpu_mode=args.cpu,
              after_construct_launcher_callbacks=[register_evaluate_fn],
              opts=args.opts)
    print('train time: %s seconds' % (time.perf_counter() - train_start))
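
# Example invocation (a sketch: the flag names mirror the attributes read
# from args above, but the authoritative definitions are in the parser
# provided by simplecv.dp_train; the paths below are placeholders):
#
#   python train.py --config_path <path/to/config.py> --model_dir <log_dir>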