test.py
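"""Evaluation script for a CASIA-SURF anti-spoofing model.

Loads a trained checkpoint, runs it over the dev split of the chosen protocol
and reports the APCER, BPCER and ACER metrics. Optionally plots the
misclassified samples when --visualize is passed.
"""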
from baseline.datasets import CasiaSurfDataset
from tqdm import tqdm
import models
from torch.utils import data
from torch import nn
from sklearn import metrics
from argparse import ArgumentParser
import utils
import numpy as np
import torch
import os
import yaml
from transforms import ValidationTransform
def evaluate(dataloader: data.DataLoader, model: nn.Module, visualize: bool = False):
    device = next(model.parameters()).device
    model.eval()
    print("Evaluating...")
    tp, tn, fp, fn = 0, 0, 0, 0
    # Structured array of misclassified samples: image, true label, predicted probability.
    errors = np.array([], dtype=[('img', object), ('label', object), ('prob', float)])
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            images, labels = batch
            outputs = model(images.to(device)).cpu()
            preds = torch.max(outputs.data, 1)[1]
            # Accumulate the confusion matrix over the whole dev set.
            tn_batch, fp_batch, fn_batch, tp_batch = metrics.confusion_matrix(
                y_true=labels, y_pred=preds, labels=[0, 1]).ravel()
            if visualize:
                # Collect the misclassified samples together with the confidence
                # of the (wrong) prediction, so they can be sorted and plotted later.
                probs = torch.softmax(outputs.data, dim=1).max(dim=1)[0]
                errors_idx = np.where(preds != labels)
                errors_imgs = np.array(list(zip(images[errors_idx], labels[errors_idx],
                                                probs[errors_idx])), dtype=errors.dtype)
                errors = np.append(errors, errors_imgs)
            tp += tp_batch
            tn += tn_batch
            fp += fp_batch
            fn += fn_batch
    # APCER: attack samples accepted as bona fide; BPCER: bona fide samples rejected as attacks.
    apcer = fp / (tn + fp) if fp != 0 else 0
    bpcer = fn / (fn + tp) if fn != 0 else 0
    acer = (apcer + bpcer) / 2
    if visualize:
        # Show the most confident errors first.
        errors.sort(order='prob')
        errors = np.flip(errors)
        utils.plot_classes_preds(model, zip(*errors))
    return apcer, bpcer, acer
def main(args, config):
    # Instantiate the model class named in the config and load the checkpoint weights.
    model = getattr(models, config['model'])(num_classes=config['num_classes'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.to(device)
    model.eval()
    with torch.no_grad():
        dataset = CasiaSurfDataset(
            args.protocol, mode='dev', dir=args.data_dir, transform=ValidationTransform(),
            depth=config['depth'], ir=config['ir'])
        dataloader = data.DataLoader(
            dataset, batch_size=args.batch_size, num_workers=args.num_workers)
        apcer, bpcer, acer = evaluate(dataloader, model, args.visualize)
        print(f'APCER: {apcer}, BPCER: {bpcer}, ACER: {acer}')
if __name__ == '__main__':
    argparser = ArgumentParser()
    argparser.add_argument('--protocol', type=int, required=True)
    argparser.add_argument('--data-dir', type=str,
                           default=os.path.join('data', 'CASIA_SURF'))
    argparser.add_argument('--checkpoint', type=str, required=True)
    argparser.add_argument('--config-path', type=str, required=True)
    argparser.add_argument('--batch_size', type=int, default=1)
    # `type=bool` would turn any non-empty string into True, so expose this as a flag instead.
    argparser.add_argument('--visualize', action='store_true')
    argparser.add_argument('--num_workers', type=int, default=0)
    args = argparser.parse_args()
    with open(args.config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    main(args, config)
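
# Example invocation (the paths, protocol number and config values below are
# placeholders; adjust them to your own checkpoint, data location and model):
#
#   python test.py --protocol 1 \
#       --checkpoint checkpoints/model_best.pth \
#       --config-path config.yaml \
#       --batch_size 32 --num_workers 4 --visualize
#
# where config.yaml supplies at least the keys read above, e.g.:
#
#   model: <ModelClassName>   # name of a class exposed by the models module
#   num_classes: 2
#   depth: true               # whether the depth modality is fed to the model
#   ir: true                  # whether the IR modality is fed to the model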