-
Notifications
You must be signed in to change notification settings - Fork 21
/
infer.py
112 lines (85 loc) · 3.05 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# -*- coding: utf-8 -*-
"""
@File : infer.py
@Time : 2019/7/1 7:39
@Author : Parker
@Email : [email protected]
@Software: PyCharm
@Des :
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import torch.optim as optim
from tensorboardX import SummaryWriter
import numpy as np
import time
import datetime
import argparse
import os
import os.path as osp
from rs_dataset import RSDataset,InferDataset
from get_logger import get_logger
from networks import Dense201
def parse_args(argv=None):
    """Parse command-line options for inference / validation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``test_batch_size``, ``num_workers``,
        ``data_dir``, ``model_out_name`` and ``result_name``.
    """
    parse = argparse.ArgumentParser()
    parse.add_argument('--test_batch_size', type=int, default=128,
                       help='batch size used by the DataLoader')
    parse.add_argument('--num_workers', type=int, default=8,
                       help='number of DataLoader worker processes')
    # Raw string: the previous literal relied on the invalid escape '\d'
    # (DeprecationWarning / SyntaxWarning); the resulting path is unchanged.
    parse.add_argument('--data_dir', type=str, default=r'C:\dataset\rscup',
                       help='dataset root directory')
    parse.add_argument('--model_out_name', type=str,
                       default='./model_out/190810-200202_dense201_pre_warm_de_cos_mix/final_model.pth',
                       help='path of the saved model state_dict to load')
    parse.add_argument('--result_name', type=str, default='result.txt',
                       help='file name for the inference results')
    return parse.parse_args(argv)
def main_worker(args):
    """Predict a class for every image of the inference set and write them out.

    Loads a ``Dense201`` state dict from ``args.model_out_name``, runs it on
    ``InferDataset(rootpth=args.data_dir)`` on the GPU, and writes one
    ``"<name> <label>"`` line per image to ``args.result_name`` (previously the
    flag was ignored and output always went to ``classification.txt``).
    ``label`` is the arg-max class index plus one, i.e. 1-based.

    Args:
        args: Namespace from ``parse_args()``; reads ``data_dir``,
            ``test_batch_size``, ``num_workers``, ``model_out_name`` and
            ``result_name``.
    """
    data_set = InferDataset(rootpth=args.data_dir)
    data_loader = DataLoader(data_set,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True)
    net = Dense201()
    # Load tensors onto the CPU first so a checkpoint saved from any GPU
    # layout can be restored; the model is moved to CUDA afterwards.
    net.load_state_dict(torch.load(args.model_out_name, map_location='cpu'))
    net.cuda()
    net.eval()
    with open(args.result_name, 'w') as f, torch.no_grad():
        for img, names in data_loader:
            outputs = net(img.cuda())
            # Softmax is monotonic, so arg-max over the raw outputs selects
            # the same class without the extra normalisation pass.
            predicted = outputs.argmax(dim=1).cpu().numpy()
            for name, label in zip(names, predicted):
                f.write('{} {}\n'.format(name, label + 1))
    print('----------Done!----------')
def evaluate_val(args):
    """Print the top-1 accuracy of the saved model on the validation split.

    Args:
        args: Namespace from ``parse_args()``; reads ``data_dir``,
            ``test_batch_size``, ``num_workers`` and ``model_out_name``.
    """
    val_set = RSDataset(rootpth=args.data_dir, mode='val')
    # Evaluate every sample in a fixed order: drop_last=True silently left the
    # final partial batch out of the accuracy, and produced total == 0 (and a
    # ZeroDivisionError) when the split was smaller than one batch.
    val_loader = DataLoader(val_set,
                            batch_size=args.test_batch_size,
                            drop_last=False,
                            shuffle=False,
                            pin_memory=True,
                            num_workers=args.num_workers)
    net = Dense201()
    # Load onto the CPU first so checkpoints saved from any device restore.
    net.load_state_dict(torch.load(args.model_out_name, map_location='cpu'))
    net.cuda()
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for img, lb in val_loader:
            img, lb = img.cuda(), lb.cuda()
            # Arg-max of the raw outputs equals arg-max after softmax.
            predicted = net(img).argmax(dim=1)
            total += lb.size(0)
            correct += (predicted == lb).sum().item()
    if total == 0:
        print('correct:0/0 (validation set is empty)')
    else:
        print('correct:{}/{}={:.4f}'.format(correct, total, correct / total))
    print('----------Done!----------')
if __name__ == '__main__':
    # Default entry point: write predictions for the inference set.
    # To score the checkpoint on the validation split instead, call
    # evaluate_val(parse_args()) here.
    main_worker(parse_args())