# compute_madd.py
"""
compute Multiply-Adds(MAdd) of each leaf module
"""
import torch.nn as nn

def compute_Conv2d_madd(module, inp, out):
    assert isinstance(module, nn.Conv2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    in_c = inp.size()[1]
    k_h, k_w = module.kernel_size
    out_c, out_h, out_w = out.size()[1:]
    groups = module.groups

    # ops per output element
    kernel_mul = k_h * k_w * (in_c // groups)
    kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)

    kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups)
    kernel_add_group = kernel_add * out_h * out_w * (out_c // groups)

    total_mul = kernel_mul_group * groups
    total_add = kernel_add_group * groups

    return total_mul + total_add
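
# Quick sanity check for compute_Conv2d_madd (hypothetical shapes):
# nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False) on a 1x3x32x32 input
# produces a 1x16x32x32 output; each output element costs 27 muls + 26 adds,
# so the total is 53 * 16 * 32 * 32 = 868,352.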

def compute_ConvTranspose2d_madd(module, inp, out):
    assert isinstance(module, nn.ConvTranspose2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    in_c, in_h, in_w = inp.size()[1:]
    k_h, k_w = module.kernel_size
    out_c, out_h, out_w = out.size()[1:]
    groups = module.groups

    kernel_mul = k_h * k_w * (in_c // groups)
    kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)

    # unlike Conv2d, the kernel is applied once per *input* spatial position,
    # hence in_h * in_w rather than out_h * out_w
    kernel_mul_group = kernel_mul * in_h * in_w * (out_c // groups)
    kernel_add_group = kernel_add * in_h * in_w * (out_c // groups)

    total_mul = kernel_mul_group * groups
    total_add = kernel_add_group * groups

    return total_mul + total_add
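
# Quick sanity check for compute_ConvTranspose2d_madd (hypothetical shapes):
# nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2, bias=False) on a
# 1x16x16x16 input produces a 1x8x32x32 output; the kernel fires once per
# input position, so the total is (64 + 63) * 16 * 16 * 8 = 260,096.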

def compute_BatchNorm2d_madd(module, inp, out):
    assert isinstance(module, nn.BatchNorm2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    in_c, in_h, in_w = inp.size()[1:]

    # four ops per element:
    # 1. subtract the mean
    # 2. divide by the standard deviation
    # 3. multiply by alpha (weight)
    # 4. add beta (bias)
    return 4 * in_c * in_h * in_w
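
# Note: this counts the four unfused ops; at inference the normalization can
# be folded into a single scale-and-shift (2 ops per element). For a
# 1x16x32x32 input the function returns 4 * 16 * 32 * 32 = 65,536.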

def compute_MaxPool2d_madd(module, inp, out):
    assert isinstance(module, nn.MaxPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    if isinstance(module.kernel_size, (tuple, list)):
        k_h, k_w = module.kernel_size
    else:
        k_h, k_w = module.kernel_size, module.kernel_size
    out_c, out_h, out_w = out.size()[1:]

    return (k_h * k_w - 1) * out_h * out_w * out_c
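
# Each k_h x k_w window needs k_h*k_w - 1 pairwise comparisons, counted here
# as adds. E.g. (hypothetical shapes) nn.MaxPool2d(2) on a 1x16x32x32 input
# gives a 1x16x16x16 output and (2*2 - 1) * 16 * 16 * 16 = 12,288 ops.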

def compute_AvgPool2d_madd(module, inp, out):
    assert isinstance(module, nn.AvgPool2d)
    assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())

    if isinstance(module.kernel_size, (tuple, list)):
        k_h, k_w = module.kernel_size
    else:
        k_h, k_w = module.kernel_size, module.kernel_size
    out_c, out_h, out_w = out.size()[1:]

    kernel_add = k_h * k_w - 1
    kernel_avg = 1
    return (kernel_add + kernel_avg) * (out_h * out_w) * out_c
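
# Averaging a window costs k_h*k_w - 1 adds plus one division. E.g.
# (hypothetical shapes) nn.AvgPool2d(2) on a 1x16x32x32 input gives a
# 1x16x16x16 output and (3 + 1) * 16 * 16 * 16 = 16,384 ops.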

def compute_ReLU_madd(module, inp, out):
    assert isinstance(module, (nn.ReLU, nn.ReLU6))

    count = 1
    for i in inp.size()[1:]:
        count *= i
    return count
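
# One comparison per element, with the batch dimension excluded: a 1x3x32x32
# input counts 3 * 32 * 32 = 3,072 ops. (ReLU6 clamps at both ends and
# arguably needs two comparisons per element; it is counted the same as
# ReLU here.)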

def compute_Softmax_madd(module, inp, out):
    assert isinstance(module, nn.Softmax)
    assert len(inp.size()) > 1

    count = 1
    for s in inp.size()[1:]:
        count *= s

    exp = count
    add = count - 1
    div = count
    return exp + add + div
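
# Per sample: one exp per feature, count - 1 adds for the normalizer, and one
# division per feature. E.g. softmax over 1,000 logits costs
# 1000 + 999 + 1000 = 2,999 ops.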

def compute_Linear_madd(module, inp, out):
    assert isinstance(module, nn.Linear)
    assert len(inp.size()) == 2 and len(out.size()) == 2

    num_in_features = inp.size()[1]
    num_out_features = out.size()[1]

    mul = num_in_features
    add = num_in_features - 1
    return num_out_features * (mul + add)
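
# Each output feature is a dot product of length in_features: n muls and
# n - 1 adds. E.g. nn.Linear(1024, 512) costs 512 * (1024 + 1023) =
# 1,048,064 ops. (The bias add is not counted here, unlike Conv2d above.)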

def compute_Bilinear_madd(module, inp1, inp2, out):
    assert isinstance(module, nn.Bilinear)
    assert len(inp1.size()) == 2 and len(inp2.size()) == 2 and len(out.size()) == 2

    num_in_features_1 = inp1.size()[1]
    num_in_features_2 = inp2.size()[1]
    num_out_features = out.size()[1]

    # Each output feature computes (x1^T A_k) x2: the row-vector product
    # x1^T A_k costs n1*n2 muls and (n1 - 1)*n2 adds, and the final dot
    # product costs n2 muls and n2 - 1 adds, i.e. n1*n2 - 1 adds in total.
    mul = num_in_features_1 * num_in_features_2 + num_in_features_2
    add = num_in_features_1 * num_in_features_2 - 1
    return num_out_features * (mul + add)
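
# E.g. (hypothetical shapes) nn.Bilinear(20, 30, 40) costs
# 40 * ((20*30 + 30) + (20*30 - 1)) = 40 * (630 + 599) = 49,160 ops with the
# counting above. (The bias add is not counted.)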

def compute_madd(module, inp, out):
    if isinstance(module, nn.Conv2d):
        return compute_Conv2d_madd(module, inp, out)
    elif isinstance(module, nn.ConvTranspose2d):
        return compute_ConvTranspose2d_madd(module, inp, out)
    elif isinstance(module, nn.BatchNorm2d):
        return compute_BatchNorm2d_madd(module, inp, out)
    elif isinstance(module, nn.MaxPool2d):
        return compute_MaxPool2d_madd(module, inp, out)
    elif isinstance(module, nn.AvgPool2d):
        return compute_AvgPool2d_madd(module, inp, out)
    elif isinstance(module, (nn.ReLU, nn.ReLU6)):
        return compute_ReLU_madd(module, inp, out)
    elif isinstance(module, nn.Softmax):
        return compute_Softmax_madd(module, inp, out)
    elif isinstance(module, nn.Linear):
        return compute_Linear_madd(module, inp, out)
    elif isinstance(module, nn.Bilinear):
        return compute_Bilinear_madd(module, inp[0], inp[1], out)
    else:
        print(f"[MAdd]: {type(module).__name__} is not supported!")
        return 0
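
# Minimal usage sketch (illustrative shapes, assuming torch is installed):
# runs one Conv2d forward pass and reproduces the hand-computed count above.
if __name__ == "__main__":
    import torch

    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
    x = torch.randn(1, 3, 32, 32)
    y = conv(x)
    # 27 muls + 26 adds per output element, times 16 * 32 * 32 elements
    print(compute_madd(conv, x, y))  # -> 868352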