-
Notifications
You must be signed in to change notification settings - Fork 3
/
batch_norm.py
98 lines (82 loc) · 2.35 KB
/
batch_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import torch
from torch import nn
from torch.utils import mkldnn as mkldnn_utils
from time import time
import copy
# Timed iterations per shape when benchmarking inference; run_single_test
# uses a tenth of this when --train is given.
iters = 1000
# Batch size N used for every benchmarked shape.
bs = 1

# Command-line interface: only a single switch selecting training mode.
parser = argparse.ArgumentParser(description='BatchNorm2d')
parser.add_argument(
    '--train',
    action='store_true',
    default=False,
    help='do training',
)
args = parser.parse_args()
def _time_module(module, x, grad_output, niters, train):
    """Return the mean time, in milliseconds, of one iteration of ``module``.

    Runs ``niters // 10`` untimed warm-up iterations, then ``niters`` timed
    ones. In training mode an iteration is ``module(x)`` followed by
    ``.backward(grad_output)``; otherwise it is the forward call only.

    Args:
        module: the layer under test.
        x: input tensor. In training mode it is marked ``requires_grad`` so
            the backward pass also produces the gradient w.r.t. the input.
        grad_output: upstream gradient passed to ``backward`` (training only).
        niters: number of timed iterations; must be >= 1.
        train: whether each iteration includes the backward pass.
    """
    from time import perf_counter  # monotonic, high-resolution benchmark clock

    if train:
        # Mark the input once, outside the timed loop, so that every layout
        # being compared computes grad_input in its backward pass.
        x.requires_grad_()

    def step():
        out = module(x)
        if train:
            out.backward(grad_output)

    # Warm up with exactly the work that is timed (forward + backward when
    # training) so one-time setup costs stay out of the measurement.
    for _ in range(niters // 10):
        step()
    start = perf_counter()
    for _ in range(niters):
        step()
    return (perf_counter() - start) / niters * 1000


def run_single_test(sz, *, train=None, batch_size=None, iterations=None):
    """Benchmark ``nn.BatchNorm2d`` on one shape in two memory layouts.

    Prints the mean per-iteration time for the default contiguous layout
    and for ``torch.channels_last``.

    Args:
        sz: shape indexable as ``[N, C, H, W]``. Only ``C``, ``H`` and ``W``
            are used; the batch size comes from ``batch_size``.
        train: benchmark forward + backward with the module in training
            mode. Defaults to the ``--train`` command-line flag.
        batch_size: batch size ``N``. Defaults to the module-level ``bs``.
        iterations: timed iterations for inference. Training uses a tenth of
            this, clamped to at least one. Defaults to the module-level
            ``iters``.
    """
    if train is None:
        train = args.train
    if batch_size is None:
        batch_size = bs
    if iterations is None:
        iterations = iters

    n, c, h, w = batch_size, sz[1], sz[2], sz[3]
    # Backward is far slower, so training runs a tenth of the iterations.
    # Clamp to 1 so small iteration counts cannot divide by zero.
    niters = max(1, iterations // 10 if train else iterations)

    x = torch.randn(n, c, h, w)
    grad_output = torch.randn(n, c, h, w)
    m = nn.BatchNorm2d(c)
    m.train(bool(train))

    # Channels-last copies. They are made before x is marked requires_grad,
    # so x_cl is an independent leaf tensor.
    x_cl = x.clone().to(memory_format=torch.channels_last)
    m_cl = copy.deepcopy(m).to(memory_format=torch.channels_last)
    grad_output_cl = grad_output.clone().to(memory_format=torch.channels_last)

    tt = _time_module(m, x, grad_output, niters, train)
    tt2 = _time_module(m_cl, x_cl, grad_output_cl, niters, train)

    print('BatchNorm size(contiguous): [{},{},{},{}]: {:.3f} ms'.format(n, c, h, w, tt))
    print('BatchNorm size(channels last): [{},{},{},{}]: {:.3f} ms'.format(n, c, h, w, tt2))
# BatchNorm2d input shapes [N, C, H, W] to benchmark. Going by the name, these
# are presumably the BatchNorm layer shapes of ResNet-50 at batch size 1 —
# NOTE(review): confirm. A shape may appear more than once
# (e.g. [1, 256, 14, 14]); each occurrence is benchmarked separately.
rn50_bn_sizes = [
    [1, 64, 112, 112],
    [1, 64, 56, 56],
    [1, 256, 56, 56],
    [1, 128, 56, 56],
    [1, 128, 28, 28],
    [1, 512, 28, 28],
    [1, 256, 28, 28],
    [1, 256, 14, 14],
    [1, 1024, 14, 14],
    [1, 256, 14, 14],
    [1, 512, 14, 14],
    [1, 512, 7, 7],
    [1, 2048, 7, 7],
]

# Run the benchmark only when executed as a script, not when imported.
if __name__ == '__main__':
    for sz in rn50_bn_sizes:
        run_single_test(sz)