-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
63 lines (52 loc) · 3.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from torch.utils.data import DataLoader, Dataset
import os
from torchmetrics.functional.classification import multilabel_accuracy
from sklearn.metrics import roc_auc_score, multilabel_confusion_matrix, confusion_matrix
from libauc.losses import AUCMLoss, CrossEntropyLoss
from opacus import PrivacyEngine
from data import get_data
from numpy.linalg import norm
from torch.func import grad, grad_and_value, vmap
from opacus.grad_sample import GradSampleModule
from models import LogisticRegresion, CNN, TwoLayer
import config
import time
from sklearn.model_selection import ParameterGrid
import clip_train
import wrn_train
import scatternet_train
def _output_dim(dataset):
    """Return the model output dimension for the dataset name *dataset*.

    Raises:
        ValueError: if *dataset* matches none of the known name prefixes.
    """
    if dataset.startswith(('chexpert', 'eyepacs')):
        return 5
    if dataset.startswith('cifar10'):
        return 10
    raise ValueError('dataset not implemented')


# Substrings of params['baseline'] that select one of the trainers in
# _run_baseline; a baseline containing none of them is rejected up front.
_BASELINES = ('clip', 'wrn', 'scatternet')


def _result_stem(params, result_dir, time_stamp):
    """Return the shared path prefix for one configuration's CSV and log files.

    The CSV and the log differ only by suffix, so building the prefix once keeps
    their names in sync (the original log name dropped an underscore).
    """
    aug = f'_aug_{params["n_augs"]}' if params['aug_multiplicity'] else ''
    return (
        f'{result_dir}/{time_stamp}_{params["baseline"]}_{params["dataset"]}{aug}'
        f'_{params["model"]}_epochs{params["epochs"]}_privacy{params["privacy"]}'
        f'_reps{params["reps"]}_ema{params["ema_flag"]}_norm{params["norm_flag"]}'
        f'_{params["group_norm_groups"]}_batchsize{params["minibatch_size"]}'
    )


def _run_baseline(params):
    """Run the trainer selected by params['baseline'] for one configuration."""
    baseline = params['baseline']
    if 'clip' in baseline:
        clip_train.train_clip(params)
    elif 'wrn' in baseline:
        # Only the WRN path repeats params['reps'] times here; the other
        # trainers are called once per configuration.
        for _ in range(params['reps']):
            if params['privacy']:
                wrn_train.wrn_train(params)
            else:
                # `wrn_nonprivate` was referenced but never imported, which made
                # this branch a NameError. Import lazily so private-only runs do
                # not depend on the module being present.
                import wrn_nonprivate
                wrn_nonprivate.wrn_train(params)
    elif 'scatternet' in baseline:
        scatternet_train.train_scatternet(params)


params = config.params
variables_dict = config.variable_parameters_dict
# ParameterGrid expands the dict of candidate values into one dict per combination.
variables = list(ParameterGrid(variables_dict))
result_dir = config.RESULTS_DIR
os.makedirs(result_dir, exist_ok=True)

for conf in variables:
    # Mutates config.params in place with this grid point's values.
    params.update(conf)
    # Recomputed per config so a grid that varies 'dataset' gets the right size.
    params['output_dim'] = _output_dim(params['dataset'])
    if not any(name in params['baseline'] for name in _BASELINES):
        raise ValueError('baseline not implemented')

    time_stamp_now = time.strftime("%Y%m%d-%H%M%S")
    stem = _result_stem(params, result_dir, time_stamp_now)
    # `with` guarantees both files are flushed and closed even if training raises.
    with open(f'{stem}_excel.csv', 'w') as result_file_excel, \
            open(f'{stem}_excel.txt', 'w') as log_file:
        # Trainers receive the open handles through params.
        params['result_file_csv'] = result_file_excel
        params['log_file'] = log_file
        result_file_excel.write('name,epochs,privacy,epsilon,delta,max_grad_norm,optim,LR,WS,EMA,BatchSize,input_norm,train_auc,train_auc_std,test_auc,test_auc_std, test_acc\n')
        _run_baseline(params)