-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathengine.py
116 lines (81 loc) · 3.6 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
import utils
import torch.nn.functional as F
def train_one_epoch(model, criterion,
                    data_loader, optimizer,
                    device: torch.device, epoch):
    """Run one training pass of ``model`` over ``data_loader``.

    Every batch is moved to ``device``, forwarded through ``model`` inside
    ``torch.cuda.amp.autocast()``, and followed by a single optimizer step
    on the loss returned by ``criterion``.

    Args:
        model: Module to train; ``model.train()`` is called first.
        criterion: Callable invoked as ``criterion(outputs, targets)``; its
            result must support ``.item()`` and ``.backward()``.
        data_loader: Iterable yielding ``(samples, targets)`` pairs.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        device: Device the samples and targets are moved to.
        epoch: Value shown in the progress header only.

    Returns:
        A dict mapping each meter name of the metric logger to that
        meter's ``global_avg``.

    Note:
        The process is terminated with ``sys.exit(1)`` as soon as a
        non-finite loss value is observed.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    # torch.cuda.synchronize() raises on hosts without CUDA, so only call it
    # when a CUDA runtime is actually present. Checked once, outside the loop.
    cuda_available = torch.cuda.is_available()

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        # NOTE(review): autocast is used without a GradScaler; confirm that
        # unscaled half-precision gradients are acceptable for this setup.
        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        # Abort before the optimizer can be corrupted by a NaN/inf update.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if cuda_available:
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def train_one_epoch_distillation(teacher, student, criterion,
                                 data_loader, optimizer,
                                 device: torch.device, epoch, alpha=1.0, temp=1.0):
    """Train ``student`` for one pass over ``data_loader`` with a teacher.

    For every batch the loss is::

        criterion(student_out, targets)
            + alpha * temp**2 * KL(softmax(teacher_out / temp)
                                   || softmax(student_out / temp))

    The ``temp**2`` factor keeps the gradient magnitude of the soft-target
    term comparable across temperatures.

    NOTE(review): the additive ``base + alpha * distill`` combination is one
    common convention; confirm it matches the intended meaning of ``alpha``
    (another convention weights the terms by ``1 - alpha`` and ``alpha``).

    Args:
        teacher: Module providing the soft targets; put in eval mode and
            evaluated without gradient tracking.
        student: Module being trained; put in training mode.
        criterion: Callable invoked as ``criterion(outputs, targets)`` for
            the supervised part of the loss.
        data_loader: Iterable yielding ``(samples, targets)`` pairs.
        optimizer: Object whose ``zero_grad()`` and ``step()`` are called
            once per batch.
        device: Device the samples and targets are moved to.
        epoch: Value shown in the progress header only.
        alpha: Weight of the distillation term.
        temp: Softmax temperature applied to both sets of logits.

    Returns:
        A dict mapping each meter name of the metric logger to that
        meter's ``global_avg``.

    Note:
        The process is terminated with ``sys.exit(1)`` as soon as a
        non-finite loss value is observed.
    """
    teacher.eval()
    student.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

    # torch.cuda.synchronize() raises on hosts without CUDA, so only call it
    # when a CUDA runtime is actually present. Checked once, outside the loop.
    cuda_available = torch.cuda.is_available()

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            outputs = student(samples)
            # The teacher is frozen: skip building its autograd graph.
            with torch.no_grad():
                outputs_teacher = teacher(samples)

            base_loss = criterion(outputs, targets)
            # KLDivLoss expects log-probabilities as input and probabilities
            # as target. Assumes the class axis is the last one - TODO confirm.
            distill_loss = kl_loss(
                F.log_softmax(outputs / temp, dim=-1),
                F.softmax(outputs_teacher / temp, dim=-1),
            ) * (temp * temp)
            loss = base_loss + alpha * distill_loss

        loss_value = loss.item()

        # Abort before the optimizer can be corrupted by a NaN/inf update.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if cuda_available:
            torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)

    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
    """Evaluate ``model`` on ``data_loader`` and report loss and accuracy.

    Args:
        data_loader: Iterable yielding ``(images, target)`` pairs.
        model: Module to evaluate; ``model.eval()`` is called before the loop.
        criterion: Callable invoked as ``criterion(output, target)``; its
            result must support ``.item()``.
        device: Device the images and targets are moved to.

    Returns:
        A dict mapping each meter name of the metric logger (``loss``,
        ``acc1``, ``acc5``) to that meter's ``global_avg``.
    """
    metric_logger = utils.MetricLogger(delimiter=" ")
    log_prefix = 'Test:'

    # Evaluation mode has to be active before the first forward pass.
    model.eval()

    for inputs, labels in metric_logger.log_every(data_loader, 20, log_prefix):
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Forward pass and loss under autocast.
        with torch.cuda.amp.autocast():
            logits = model(inputs)
            batch_loss = criterion(logits, labels)

        top1, top5 = accuracy(logits, labels, topk=(1, 5))
        num_samples = inputs.shape[0]

        metric_logger.update(loss=batch_loss.item())
        # Accuracy meters are weighted by the number of samples in the batch.
        for meter_name, value in (('acc1', top1), ('acc5', top5)):
            metric_logger.meters[meter_name].update(value.item(), n=num_samples)

    summary = '* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
    print(summary.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    results = {}
    for meter_name, meter in metric_logger.meters.items():
        results[meter_name] = meter.global_avg
    return results