-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_feature_density_estimator.py
161 lines (126 loc) · 6.13 KB
/
train_feature_density_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
# Evan Cook, [email protected]
# January 2024
from tqdm import tqdm
import pickle
import numpy as np
import torch
import random
import argparse
import sys
from datetime import datetime
import os
import normflows as nf # https://github.com/VincentStimper/normalizing-flows
import fde
import pickle as pk
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# To profile, measure start and end time using time.now() and pass both
# timestamps to delta_t().
# `time` is bound to the datetime class itself (not to a timestamp created with
# the deprecated datetime.utcnow()), so time.now() is an ordinary classmethod
# call that returns the current local time.
# NOTE: this name shadows the standard-library `time` module inside this file.
time = datetime
def delta_t(a, b):
    """Return the absolute difference between ``a`` and ``b`` in seconds.

    The arguments are subtracted directly, so they must be objects whose
    difference supports ``abs()`` and ``total_seconds()`` (for example two
    ``datetime`` values). The order of the arguments does not matter.
    """
    elapsed = a - b
    return abs(elapsed).total_seconds()
####################################################################################################
# Configuration (may be overridden by arguments)
####################################################################################################
# Project constants.
# main() overwrites "output path", "model name", "epoch count", "batch size"
# and "seed" from the command line; the remaining entries keep these defaults.
# NOTE(review): "flow_architecture" and "val data count" are not read anywhere
# in the visible script - confirm whether other modules rely on them.
config = {
    "flow_architecture": "Glow",
    "model name": "ResNet50",
    "dataset": "ImageNet1k",
    "output path": "checkpoint/flow_resnet50.pt",
    "try cuda": True,
    "epoch count": 1,
    "batch size": 250,
    "val data count": 50000,
    "learning rate": 1e-5,
    "seed": 42,
    # expanduser("~") resolves to $HOME when it is set, but unlike
    # os.environ['HOME'] it does not raise KeyError at import time on systems
    # where HOME is undefined (e.g. Windows).
    "imagenet dir": os.path.join(os.path.expanduser("~"), "data", "ImageNet1k"),
}
####################################################################################################
# Training Code
####################################################################################################
# Inner training loop for flow_model - run one epoch through the dataloader
def train(flow_model, classifier_model, device, dataloader, optimizer):
    """Train ``flow_model`` for one pass over ``dataloader``.

    Each batch is passed through ``fde.get_features`` with gradients disabled,
    and the flow is optimised to minimise the mean negative log-probability
    that ``flow_model.log_prob`` assigns to the resulting features.

    Args:
        flow_model: Model exposing ``train()`` and ``log_prob(features)``;
            its parameters are the ones updated by ``optimizer``.
        classifier_model: Backbone handed to ``fde.get_features``; it is put
            in eval mode and never receives gradients here.
        device: Device the input batches are moved to.
        dataloader: Iterable yielding ``(data, target)`` pairs. The targets
            are ignored.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The mean of the per-batch losses as a float, or NaN when the
        dataloader yields no batches.
    """
    flow_model.train()
    classifier_model.eval()
    losses = []
    progress_bar = tqdm(dataloader)
    # Labels are not needed for density estimation, so they are neither moved
    # to the device nor kept.
    for data, _ in progress_bar:
        # Reset gradients
        optimizer.zero_grad()
        # Load the data from the training dataset
        data = data.to(device)
        # Run the data through our backbone so we can extract the feature
        # representations; no gradients are needed for the backbone.
        with torch.no_grad():
            _, features = fde.get_features(classifier_model, data, device)
        loss = -torch.mean(flow_model.log_prob(features))
        # Do backpropagation and perform gradient descent
        loss.backward()
        optimizer.step()
        # Convert the loss tensor to a Python float once and reuse it.
        loss_value = float(loss)
        losses.append(loss_value)
        progress_bar.set_description("Training, loss = {:.2f}".format(loss_value))
    if not losses:
        # Same result np.mean([]) would give, without the numpy RuntimeWarning.
        return float("nan")
    return float(np.mean(losses))
def run():
    """Train a normalizing flow on backbone features and save its weights.

    Reads every setting from the module-level ``config`` dict: seeds the
    random number generators, selects the device, builds the classifier
    backbone and the flow model, trains for ``config["epoch count"]`` epochs
    on the ImageNet training split, and finally writes the flow's
    ``state_dict`` to ``config["output path"]``.

    Exits the process with status 1 when ``fde.create_classifier`` returns
    ``None`` for the configured model name.
    """
    print("Training {}...".format(config["model name"]))

    ################################################################################################
    # Environment setup
    ################################################################################################
    random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    use_cuda = config["try cuda"] and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(">> Training device:", device)
    kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}

    # Create the checkpoint directory up front so that a missing directory
    # cannot make torch.save() fail after all the training work is done.
    output_dir = os.path.dirname(config["output path"])
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    ################################################################################################
    # Models
    ################################################################################################
    # Create our classifier backbone
    classifier_model = fde.create_classifier(config["model name"])
    if classifier_model is None:
        print("Unknown model name!")
        sys.exit(1)
    classifier_model = classifier_model.to(device)
    classifier_model.eval()

    # Create our flow model
    base = nf.distributions.base.DiagGaussian(classifier_model.feature_dimension, trainable=False)
    flows = fde.create_glow_flows(10, classifier_model.feature_dimension)
    flow_model = nf.NormalizingFlow(base, flows).to(device)

    ################################################################################################
    # Dataloader
    ################################################################################################
    train_data = datasets.ImageNet(root=config['imagenet dir'], split='train', transform=classifier_model.transform)
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=config["batch size"], shuffle=True, **kwargs)

    ################################################################################################
    # Training loop
    ################################################################################################
    optimizer = torch.optim.Adam(flow_model.parameters(), lr=config["learning rate"])
    for epoch in range(config["epoch count"]):
        tick = time.now()
        epoch_train_loss = train(flow_model, classifier_model, device, train_dataloader, optimizer)
        tock = time.now()
        print(">> Epoch {} complete, {:.2f} seconds elapsed.".format(epoch, delta_t(tick, tock)))
        # Report the epoch loss instead of silently discarding it.
        print(">> Epoch {} mean training loss: {:.4f}".format(epoch, epoch_train_loss))

    torch.save(flow_model.state_dict(), config["output path"])
    print(">> Complete.")
def main():
    """Parse the command line into the module-level ``config`` and call run().

    ``output_path`` and ``model_name`` are required positional arguments;
    ``--seed``, ``--epochs`` and ``--batch_size`` are optional and default to
    the values already stored in ``config``.
    """
    cli = argparse.ArgumentParser(description="Script to train a normalizing flow density estimation model on the feature representations of a pretrained ImageNet1k backbone")

    # Add command line arguments
    cli.add_argument('output_path', type=str, help='Output path of the normalizing flow model')
    cli.add_argument('model_name', type=str, help='Name of the pytorch pretrained ImageNet1k backbone model')
    cli.add_argument('--seed', type=int, default=config["seed"], help='Random seed')
    cli.add_argument('--epochs', type=int, default=config["epoch count"], help='Number of epochs to train the normalizing flow for (optional)')
    cli.add_argument('--batch_size', type=int, default=config["batch size"], help='Training batch size (optional)')
    parsed = cli.parse_args()

    # config is only mutated in place (never rebound), so no `global`
    # declaration is required.
    config.update({
        "output path": parsed.output_path,
        "model name": parsed.model_name,
        "epoch count": parsed.epochs,
        "batch size": parsed.batch_size,
        "seed": parsed.seed,
    })

    run()
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()