-
Notifications
You must be signed in to change notification settings - Fork 269
/
tutorial_nn.py
99 lines (84 loc) · 3.83 KB
/
tutorial_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Authors: Wouter Van Gansbeke
Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import os

import numpy as np
import torch
from termcolor import colored

from utils.config import create_config
from utils.common_config import (get_model, get_train_dataset,
                                 get_val_dataset,
                                 get_val_dataloader,
                                 get_val_transformations)
from utils.memory import MemoryBank
from utils.train_utils import simclr_train
from utils.utils import fill_memory_bank
# Parser
# Both config files are mandatory: main() passes them straight to create_config(),
# so argparse should reject a missing flag with a usage message rather than
# letting None reach create_config().
parser = argparse.ArgumentParser(description='Eval_nn')
parser.add_argument('--config_env', required=True,
                    help='Config file for the environment')
parser.add_argument('--config_exp', required=True,
                    help='Config file for the experiment')
args = parser.parse_args()
def _mine_neighbors(dataloader, model, memory_bank, topk, split, save_path):
    """Fill a memory bank with model features, mine top-k neighbors and save them.

    Args:
        dataloader: loader passed to ``fill_memory_bank`` to compute features.
        model: network used by ``fill_memory_bank`` to embed the samples.
        memory_bank: ``MemoryBank`` sized for ``dataloader``'s dataset.
        topk: number of nearest neighbors to mine per sample.
        split: split name used only in the log messages (e.g. ``'train'``).
        save_path: destination passed to ``np.save`` for the neighbor indices.

    Returns:
        The ``(indices, acc)`` pair returned by
        ``memory_bank.mine_nearest_neighbors(topk)``.
    """
    print(colored('Fill memory bank for mining the nearest neighbors ({}) ...'.format(split), 'blue'))
    fill_memory_bank(dataloader, model, memory_bank)
    print('Mine the nearest neighbors (Top-%d)' % topk)
    indices, acc = memory_bank.mine_nearest_neighbors(topk)
    print('Accuracy of top-%d nearest neighbors on %s set is %.2f' % (topk, split, 100 * acc))
    np.save(save_path, indices)
    return indices, acc


def main(topk_train=20, topk_val=5):
    """Restore the pretext model and mine k-nearest neighbors for train and val.

    Builds the config from the command-line ``args``, loads the weights from
    ``p['pretext_checkpoint']``, re-saves them to ``p['pretext_model']``, and
    writes neighbor indices to ``p['topk_neighbors_train_path']`` and
    ``p['topk_neighbors_val_path']``.

    Args:
        topk_train: neighbors mined per training sample (served as input to
            the SCAN loss).
        topk_val: neighbors mined per validation sample (used for validation).

    Raises:
        FileNotFoundError: if ``p['pretext_checkpoint']`` does not exist.
    """
    # Retrieve config file
    p = create_config(args.config_env, args.config_exp)
    print(colored(p, 'red'))

    # Model
    print(colored('Retrieve model', 'blue'))
    model = get_model(p)
    print('Model is {}'.format(model.__class__.__name__))
    # `param` rather than `p`, so the expression is not mistaken for touching the config.
    num_params = sum(param.numel() for param in model.parameters())
    print('Model parameters: {:.2f}M'.format(num_params / 1e6))
    print(model)
    model = model.cuda()

    # CUDNN
    print(colored('Set CuDNN benchmark', 'blue'))
    torch.backends.cudnn.benchmark = True

    # Dataset
    val_transforms = get_val_transformations(p)
    print('Validation transforms:', val_transforms)
    val_dataset = get_val_dataset(p, val_transforms)
    val_dataloader = get_val_dataloader(p, val_dataset)
    print('Dataset contains {} val samples'.format(len(val_dataset)))

    # Memory Bank
    print(colored('Build MemoryBank', 'blue'))
    base_dataset = get_train_dataset(p, val_transforms, split='train')  # Dataset w/o augs for knn eval
    base_dataloader = get_val_dataloader(p, base_dataset)
    memory_bank_base = MemoryBank(len(base_dataset),
                                  p['model_kwargs']['features_dim'],
                                  p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_base.cuda()
    memory_bank_val = MemoryBank(len(val_dataset),
                                 p['model_kwargs']['features_dim'],
                                 p['num_classes'], p['criterion_kwargs']['temperature'])
    memory_bank_val.cuda()

    # Checkpoint
    checkpoint_path = p['pretext_checkpoint']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError('Pretext checkpoint not found: {}'.format(checkpoint_path))
    print(colored('Restart from checkpoint {}'.format(checkpoint_path), 'blue'))
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # load_state_dict copies the weights in place into the model's parameters,
    # which were already moved to the GPU above, so no further .cuda() is needed.
    model.load_state_dict(checkpoint)

    # Save model
    torch.save(model.state_dict(), p['pretext_model'])

    # Mine the topk nearest neighbors at the very end (Train)
    # These will be served as input to the SCAN loss.
    _mine_neighbors(base_dataloader, model, memory_bank_base, topk_train,
                    'train', p['topk_neighbors_train_path'])

    # Mine the topk nearest neighbors at the very end (Val)
    # These will be used for validation.
    _mine_neighbors(val_dataloader, model, memory_bank_val, topk_val,
                    'val', p['topk_neighbors_val_path'])
if __name__ == '__main__':
    # Script entry point: run the neighbor-mining pipeline only when executed
    # directly, never as a side effect of importing this module.
    main()