from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import unittest

import numpy as np
import torch
import torch.nn.quantized.functional as qF

from hypothesis import assume, given
from hypothesis import strategies as st
import hypothesis_utils as hu

from common_quantized import _conv_output_shape
from common_utils import TestCase, run_tests


@unittest.skipIf(
    not torch.fbgemm_is_cpu_supported(),
    " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
    " with instruction set support avx2 or newer.",
)
class QuantizedConvTest(TestCase):
    @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3,
                              min_in_channels=1, max_in_channels=7,
                              min_out_channels=1, max_out_channels=7,
                              H_range=(6, 12), W_range=(6, 12),
                              kH_range=(3, 5), kW_range=(3, 5),
                              max_groups=4,
                              qparams=[hu.qparams(dtypes=torch.quint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint8,
                                                  zero_point_min=0,
                                                  zero_point_max=0),
                                       hu.qparams(dtypes=torch.qint32,
                                                  zero_point_min=0,
                                                  zero_point_max=0)]),
           padH=st.integers(1, 3), padW=st.integers(1, 3),
           sH=st.integers(1, 3), sW=st.integers(1, 3),
           dH=st.integers(1, 2), dW=st.integers(1, 2))
    def test_conv_api(self, X, padH, padW, sH, sW, dH, dW):
        """Tests the correctness of the quantized conv functional.

        Correctness is defined as matching the behavior of the reference
        `torch.ops.quantized.fbgemm_conv2d` implementation.
        """
        # Random inputs
        # X, (scale, zero_point, torch_type) = X
        (inputs, filters, bias, groups) = X
        inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs
        filters, (filters_scale, filters_zero_point, filters_qtype) = filters
        bias, (bias_scale, bias_zero_point, bias_qtype) = bias
        scale, zero_point = inputs_scale, inputs_zero_point
        torch_type = inputs_qtype
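        # The output is requantized with the input's scale, zero_point and
        # dtype; the same qparams are passed to both the reference op and
        # qF.conv2d below.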
        iC, oC = inputs.shape[1], filters.shape[0]
        iH, iW = inputs.shape[2:]
        kH, kW = filters.shape[2:]
        assume(kH // 2 >= padH)
        assume(kW // 2 >= padW)
        oH = _conv_output_shape(iH, kH, padH, sH, dH)
        assume(oH > 0)
        oW = _conv_output_shape(iW, kW, padW, sW, dW)
        assume(oW > 0)
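        # _conv_output_shape presumably computes the standard (dilated) conv
        # output size, roughly:
        #   out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1
        # The assumes above discard parameter draws that would yield an empty
        # output.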
        inputs = torch.from_numpy(inputs).to(torch.float)
        filters = torch.from_numpy(filters).to(torch.float)
        bias = torch.from_numpy(bias).to(torch.float)

        kernel_size = (kH, kW)
        stride = (sH, sW)
        i_padding = (padH, padW)
        dilation = (dH, dW)

        # Quantized inputs
        q_inputs = torch.quantize_linear(inputs, inputs_scale,
                                         inputs_zero_point, inputs_qtype)
        q_filters = torch.quantize_linear(filters, filters_scale,
                                          filters_zero_point, filters_qtype)
        q_bias = torch.quantize_linear(bias, bias_scale, bias_zero_point,
                                       bias_qtype)
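        # quantize_linear (renamed quantize_per_tensor in later PyTorch
        # releases) applies the affine mapping
        #   q = round(x / scale) + zero_point,
        # clamped to the representable range of the target dtype.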
        # Reference op
        ref_op = torch.ops.quantized.fbgemm_conv2d
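        # The fbgemm kernels operate on channels-last (NHWC) tensors, hence
        # the NCHW -> NHWC permutes around the reference call below, while
        # qF.conv2d is fed plain NCHW tensors.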
        # Results check
        try:
            q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(q_filters.permute([0, 2, 3, 1]),
                                                                    stride,
                                                                    i_padding,
                                                                    dilation,
                                                                    groups)
            ref_result = ref_op(q_inputs.permute([0, 2, 3, 1]), q_filters_ref,
                                q_bias, stride,
                                i_padding, dilation,
                                groups, scale, zero_point).permute([0, 3, 1, 2])
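        # If the reference op rejects this configuration, qF.conv2d is
        # expected to fail with the same error type and message; otherwise
        # the two results must match.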
        except RuntimeError as e:
            e_msg = str(e).split("\n")[0].split("(")[0].strip()
            np.testing.assert_raises_regex(
                type(e), e_msg, qF.conv2d,
                q_inputs, q_filters, bias=q_bias,
                scale=scale, zero_point=zero_point,
                stride=stride, padding=i_padding, dilation=dilation,
                groups=groups, dtype=torch_type)
        else:
            q_result = qF.conv2d(q_inputs,
                                 q_filters,
                                 bias=q_bias, scale=scale,
                                 zero_point=zero_point,
                                 stride=stride, padding=i_padding,
                                 dilation=dilation, groups=groups,
                                 dtype=torch_type)

            self.assertEqual(ref_result, q_result)


if __name__ == "__main__":
    run_tests()