test_qat.py (forked from pytorch/pytorch)
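
# Tests the fused QAT modules (ConvBn2d / ConvBnReLU2d) against an equivalent
# float Conv2d + BatchNorm2d (+ ReLU) reference: forward numerics, gradients,
# and BatchNorm running statistics.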
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.nn import Conv2d, BatchNorm2d, ReLU
from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d
from torch.quantization.QConfig import default_qat_qconfig
from torch.utils.mkldnn import disable_mkldnn_conv

from common_quantization import no_deadline
from common_utils import TestCase, run_tests
from hypothesis import given
from hypothesis import strategies as st
from functools import reduce


class IntrinsicQATModuleTest(TestCase):
    # NOTE: Tests in this class are decorated with no_deadline
    # to prevent spurious failures due to cuda runtime initialization.
    @no_deadline
    @given(batch_size=st.integers(2, 4),
           input_channels_per_group=st.sampled_from([2, 3, 4]),
           height=st.integers(5, 10),
           width=st.integers(5, 10),
           output_channels_per_group=st.sampled_from([2, 3]),
           groups=st.integers(1, 3),
           kernel_h=st.integers(1, 3),
           kernel_w=st.integers(1, 3),
           stride_h=st.integers(1, 2),
           stride_w=st.integers(1, 2),
           pad_h=st.integers(0, 2),
           pad_w=st.integers(0, 2),
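           # dilation is pinned to 1, so only undilated convolutions are tested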
           dilation=st.integers(1, 1),
           padding_mode=st.sampled_from(['zeros', 'circular']),
           use_relu=st.booleans(),
           eps=st.sampled_from([1e-5, 1e-4, 1e-3]),
           momentum=st.sampled_from([0.1, 0.2, 0.3]),
           freeze_bn=st.booleans())
    def test_conv_bn_relu(
            self,
            batch_size,
            input_channels_per_group,
            height,
            width,
            output_channels_per_group,
            groups,
            kernel_h,
            kernel_w,
            stride_h,
            stride_w,
            pad_h,
            pad_w,
            dilation,
            padding_mode,
            use_relu,
            eps,
            momentum,
            freeze_bn
    ):
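        # MKL-DNN conv is disabled here, presumably so the float reference and
        # the fused QAT module run the same native kernels and the exact
        # comparisons below aren't tripped by backend-specific rounding.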
        with disable_mkldnn_conv():
            input_channels = input_channels_per_group * groups
            output_channels = output_channels_per_group * groups
            dilation_h = dilation_w = dilation

            conv_op = Conv2d(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                False,  # no bias
                padding_mode
            ).to(dtype=torch.float)
            bn_op = BatchNorm2d(output_channels, eps, momentum).to(dtype=torch.float)
            relu_op = ReLU()
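
            # Fused QAT module under test. Fake quant is disabled so its float
            # numerics can be compared directly against the reference modules.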
            cls = ConvBnReLU2d if use_relu else ConvBn2d
            qat_op = cls(
                input_channels,
                output_channels,
                (kernel_h, kernel_w),
                (stride_h, stride_w),
                (pad_h, pad_w),
                (dilation_h, dilation_w),
                groups,
                padding_mode,
                eps,
                momentum,
                freeze_bn,
                default_qat_qconfig
            ).to(dtype=torch.float).disable_fake_quant()

            # align inputs and internal parameters
            input = torch.randn(batch_size, input_channels, height, width,
                                dtype=torch.float, requires_grad=True)
            conv_op.weight = torch.nn.Parameter(qat_op.weight.detach())
            bn_op.running_mean = qat_op.running_mean.clone()
            bn_op.running_var = qat_op.running_var.clone()
            bn_op.weight = torch.nn.Parameter(qat_op.gamma.detach())
            bn_op.bias = torch.nn.Parameter(qat_op.beta.detach())

            def compose(functions):
                # functions are reversed for natural reading order
                return reduce(lambda f, g: lambda x: f(g(x)), functions[::-1], lambda x: x)
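            # e.g. compose([f, g])(x) == g(f(x)); functions apply in listed order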

            if not use_relu:
                def relu_op(x):
                    return x

            if freeze_bn:
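                # With BN frozen, the reference normalizes with the fixed
                # running stats instead of batch statistics.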
                def ref_op(x):
                    x = conv_op(x)
                    x = (x - bn_op.running_mean.reshape([1, -1, 1, 1])) * \
                        (bn_op.weight / torch.sqrt(bn_op.running_var + bn_op.eps)).reshape([1, -1, 1, 1]) + \
                        bn_op.bias.reshape([1, -1, 1, 1])
                    x = relu_op(x)
                    return x
            else:
                ref_op = compose([conv_op, bn_op, relu_op])
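
            # Run the comparison twice so the BN running statistics are updated
            # (and checked) across more than one forward pass.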
            input_clone = input.clone().detach().requires_grad_()
            for i in range(2):
                result_ref = ref_op(input)
                result_actual = qat_op(input_clone)
                self.assertEqual(result_ref, result_actual)

                # backward
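                # (the loss is a plain sum, so the upstream gradient is all ones;
                # subtracting dout only shifts the loss value, not the gradients)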
                dout = torch.randn(result_ref.size(), dtype=torch.float)
                loss = (result_ref - dout).sum()
                loss.backward()
                input_grad_ref = input.grad.cpu()
                weight_grad_ref = conv_op.weight.grad.cpu()
                gamma_grad_ref = bn_op.weight.grad.cpu()
                beta_grad_ref = bn_op.bias.grad.cpu()
                running_mean_ref = bn_op.running_mean
                running_var_ref = bn_op.running_var
                num_batches_tracked_ref = bn_op.num_batches_tracked

                loss = (result_actual - dout).sum()
                loss.backward()
                input_grad_actual = input_clone.grad.cpu()
                weight_grad_actual = qat_op.weight.grad.cpu()
                gamma_grad_actual = qat_op.gamma.grad.cpu()
                beta_grad_actual = qat_op.beta.grad.cpu()
                running_mean_actual = qat_op.running_mean
                running_var_actual = qat_op.running_var
                num_batches_tracked_actual = qat_op.num_batches_tracked
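
                # The weight and gamma gradients pass through the BN rescaling
                # and accumulate more rounding error, presumably why the
                # tolerances for them are looser.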
                self.assertEqual(input_grad_ref, input_grad_actual)
                self.assertEqual(weight_grad_ref, weight_grad_actual, prec=5e-4)
                self.assertEqual(gamma_grad_ref, gamma_grad_actual, prec=1e-4)
                self.assertEqual(beta_grad_ref, beta_grad_actual)
                self.assertEqual(num_batches_tracked_ref, num_batches_tracked_actual)
                self.assertEqual(running_mean_ref, running_mean_actual)
                self.assertEqual(running_var_ref, running_var_actual)


if __name__ == '__main__':
    run_tests()