-
Notifications
You must be signed in to change notification settings - Fork 0
/
quantize_ops.py
291 lines (257 loc) · 11.7 KB
/
quantize_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
'''
Copyright (c) 2019 [Jia-Yau Shiau]
Code work by Jia-Yau ([email protected]).
Code work is advised and forked from Peter Huang ([email protected])
--------------------------------------------------
Quantization operations and fully integral calculation for Weight-Drop LSTM cell.
This implementation is based on:
https://arxiv.org/pdf/1712.05877.pdf
"Quantization and Training of Neural Networks for
Efficient Integer-Arithmetic-Only Inference"
Benoit Jacob, Skirmantas Kligys, Bo Chen, Menglong Zhu,
Matthew Tang, Andrew Howard, Hartwig Adam, Dmitry Kalenichenko
The code is modified from tensorflow source code:
tf.quantization.quantize
'''
import re
import tensorflow as tf
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common, quant_ops
from tensorflow.python.framework import dtypes, ops
from tensorflow.python.ops import clip_ops, control_flow_ops, math_ops
from tensorflow.python.platform import tf_logging as logging
def hard_sigmoid(x, name="hard_sigmoid"):
    """Piecewise-linear sigmoid approximation: ``clip(0.2 * x + 0.5, 0, 1)``.

    Args:
        x: Value to transform; it is passed straight to ``tf.multiply``.
        name: Name scope under which the ops are created.

    Returns:
        ``0.2 * x + 0.5`` clipped to the closed interval ``[0.0, 1.0]``.
    """
    with ops.name_scope(name, "HardSigmoid", [x]):
        # Ops are created in the order multiply -> add -> clip_by_value.
        scaled_and_shifted = tf.add(tf.multiply(0.2, x), 0.5)
        return clip_ops.clip_by_value(scaled_and_shifted, 0.0, 1.0)
def insert_quant_ops(ops_dict,
                     graph=None,
                     is_train=True,
                     weight_bits=8,
                     activation_bits=8,
                     ema_decay=0.999,
                     quant_delay=0,
                     vars_collection=None,
                     scope=None):
    """Inserts fake-quantization ops into an already-built LSTM cell graph.

    The function performs four steps, in this order:

    1. Quantizes the ``<lstm_matrix context>/weights`` op (last-value ranges,
       narrow range, ``weight_bits`` bits).
    2. Quantizes the output of ``<lstm_matrix context>/BiasAdd`` (moving
       average ranges, ``activation_bits`` bits, ``init_min=0.0``).
    3. If ``ops_dict`` has a ``'proj_kernel'`` entry, quantizes
       ``<proj_kernel context>/weights`` the same way as step 1.
    4. Walks the graph breadth-first, starting from the first consumer of
       every output of the first ``BiasAdd`` consumer, and inserts a
       ``post_activation`` quant op after every visited op that is not
       excluded by name (see the inline comments). The walk does not go past
       ``<'i' context>/end_m`` and ``<'i' context>/end_c``.

    Args:
        ops_dict: Mapping from string keys to graph ops. ``'lstm_matrix'``
            and ``'i'`` are required; ``'proj_kernel'`` is optional. Only the
            ``name`` attribute of each op is read here.
        graph: Graph to look ops up in. Defaults to the current default graph.
        is_train: Forwarded to ``InsertQuantOp`` as ``is_training``.
        weight_bits: Bit width used for the weight quant ops.
        activation_bits: Bit width used for the activation quant ops.
        ema_decay: EMA decay forwarded to ``InsertQuantOp``.
        quant_delay: Quantization delay forwarded to ``InsertQuantOp``.
        vars_collection: Collection for the quantization range variables.
            Defaults to ``tf.GraphKeys.GLOBAL_VARIABLES``.
        scope: Forwarded as ``producer_scope`` for the ``post_activation``
            quant ops; ``None`` means no restriction.

    Raises:
        KeyError: If ``'lstm_matrix'`` or ``'i'`` is missing from
            ``ops_dict``, or a required op cannot be found in ``graph``.
    """
    if graph is None:
        graph = tf.get_default_graph()
    if vars_collection is None:
        vars_collection = tf.GraphKeys.GLOBAL_VARIABLES

    # Step 1: weights of the LSTM matrix.
    matrix_context = _GetContextFromOp(ops_dict['lstm_matrix'])
    producer = graph.get_operation_by_name(matrix_context + '/weights')
    consumers = producer.outputs[0].consumers()
    InsertQuantOp(context=matrix_context,
                  name='weights_quant',
                  producer=producer,
                  consumers=consumers,
                  is_training=is_train,
                  moving_avg=False,
                  ema_decay=ema_decay,
                  quant_delay=quant_delay,
                  narrow_range=True,
                  vars_collection=vars_collection,
                  bits=weight_bits,
                  consumer_scope=matrix_context)

    # Step 2: output of the LSTM matrix BiasAdd.
    producer = graph.get_operation_by_name(matrix_context + '/BiasAdd')
    consumers = producer.outputs[0].consumers()
    InsertQuantOp(context=matrix_context,
                  name='act_quant',
                  producer=producer,
                  consumers=consumers,
                  is_training=is_train,
                  moving_avg=True,
                  ema_decay=ema_decay,
                  quant_delay=quant_delay,
                  vars_collection=vars_collection,
                  bits=activation_bits,
                  init_min=0.0,
                  producer_scope=matrix_context)

    post_activation_bypass_context = _GetContextFromOp(ops_dict['i'])
    # Seed the traversal with the first consumer of each output of the first
    # BiasAdd consumer. `consumers` was captured before the quant op was
    # inserted, so it still holds the original consumer ops.
    # NOTE(review): consumers[0] is presumably the op that splits the matrix
    # into gates -- confirm against the cell implementation.
    producer_list = [output.consumers()[0] for output in consumers[0].outputs]

    # Step 3: optional projection kernel weights.
    try:
        proj_kernel_op = ops_dict['proj_kernel']
    except KeyError:
        # No projection layer registered: nothing to quantize here.
        pass
    else:
        proj_context = _GetContextFromOp(proj_kernel_op)
        try:
            producer = graph.get_operation_by_name(proj_context + '/weights')
            consumers = producer.outputs[0].consumers()
            InsertQuantOp(context=proj_context,
                          name='weights_quant',
                          producer=producer,
                          consumers=consumers,
                          is_training=is_train,
                          moving_avg=False,
                          ema_decay=ema_decay,
                          quant_delay=quant_delay,
                          narrow_range=True,
                          vars_collection=vars_collection,
                          bits=weight_bits,
                          consumer_scope=proj_context)
        except KeyError as err:
            # A KeyError here means the projection op exists but a graph
            # lookup failed. Keep the historical non-raising behaviour, but
            # do not hide the fact that projection weights stay unquantized.
            logging.warning(
                'insert_quant_ops could not quantize projection weights '
                'in context "%s": %s', proj_context, err)

    # Step 4: breadth-first walk over everything downstream of the gates.
    visited_names = {post_activation_bypass_context + '/end_m',
                     post_activation_bypass_context + '/end_c'}
    while producer_list:
        producer_list_new = []
        for producer in producer_list:
            # Capture consumers before InsertQuantOp reroutes them.
            consumers = producer.outputs[0].consumers()
            op_name = producer.name
            # Ops matching any of the four conditions get no quant op of
            # their own, but the walk still continues through them.
            cond1 = ('hard_sigmoid' in op_name) and (
                not op_name.endswith('clip_by_value'))
            cond2 = ('Relu' in op_name.split('/')[-1])
            cond3 = ('hard_sigmoid' not in op_name) and (
                op_name.endswith('clip_by_value/Minimum'))
            cond4 = op_name.endswith(('add', 'add_1'))
            if not (cond1 or cond2 or cond3 or cond4):
                InsertQuantOp(context=post_activation_bypass_context,
                              name='post_activation',
                              producer=producer,
                              consumers=consumers,
                              is_training=is_train,
                              moving_avg=True,
                              ema_decay=ema_decay,
                              quant_delay=quant_delay,
                              vars_collection=vars_collection,
                              bits=activation_bits,
                              producer_scope=scope)
            visited_names.add(op_name)
            for consumer in consumers:
                if (consumer not in producer_list_new
                        and consumer.name not in visited_names):
                    producer_list_new.append(consumer)
        producer_list = producer_list_new
def InsertQuantOp(context,
                  name,
                  producer,
                  consumers,
                  is_training,
                  moving_avg=True,
                  init_min=-6.0,
                  init_max=6.0,
                  bits=8,
                  ema_decay=0.999,
                  quant_delay=None,
                  vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                  narrow_range=False,
                  producer_scope=None,
                  consumer_scope=None):
    """Places a quantization op on the edge from `producer` to `consumers`.

    The first output of `producer` is fed through a quantization op and the
    given consumers are rerouted to read the quantized tensor instead.
    Nothing is inserted when a scope restriction is violated or when the
    producer output already feeds a FakeQuant op.

    Args:
      context: Context where producer and consumer operations are nested.
      name: Name for the new quantization op within the context.
      producer: Producer operation whose first output is quantized.
      consumers: Consumer operations to reroute onto the quantized tensor.
      is_training: Whether quantizing training graph or eval graph.
      moving_avg: If true, use `quant_ops.MovingAvgQuantize`; otherwise use
        `quant_ops.LastValueQuantize`.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      bits: Number of bits to use for quantization (passed as `num_bits`).
      ema_decay: (Optional) Float, EMA decay parameter; only used when
        `moving_avg` is true.
      quant_delay: (Optional, default None) Int. When positive, the quantized
        tensor is used only once the tensor named 'global_step:0' in the
        default graph is greater than or equal to this value; before that
        the unquantized producer output is passed through.
      vars_collection: (Optional) Collection where to store the variables for
        quantization interval ends.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
      producer_scope: The restriction of producer scope. If not None, the new
        op will be inserted only when the producer is in this scope.
      consumer_scope: The restriction of consumer scope. If not None, the new
        op will be inserted only when all the consumers are in this scope.

    Raises:
      ValueError: When fewer consumer inputs were rerouted than there are
        consumers.
    """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            'InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"',
            context, name, producer.name, producer_scope)
        return

    if consumer_scope:
        # A single out-of-scope consumer aborts the whole insertion.
        for candidate in consumers:
            if not candidate.name.startswith(consumer_scope):
                logging.info(
                    'InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"',
                    context, name, candidate.name, consumer_scope)
                return
        consumers = list(consumers)

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    active_scope = ops.get_name_scope()
    if active_scope:
        name_prefix = common.DropStringPrefix(name_prefix, active_scope + '/')

    source_tensor = producer.outputs[0]

    # Bypass ops can overlap between multiple matches, so refuse to quantize
    # a tensor that already feeds a FakeQuant op.
    fake_quant_types = ('FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs')
    if any(reader.type in fake_quant_types
           for reader in source_tensor.consumers()):
        return

    shared_kwargs = dict(
        init_min=init_min,
        init_max=init_max,
        is_training=is_training,
        num_bits=bits,
        narrow_range=narrow_range,
        vars_collection=vars_collection,
        name_prefix=name_prefix)
    if moving_avg:
        quantized = quant_ops.MovingAvgQuantize(
            source_tensor, ema_decay=ema_decay, **shared_kwargs)
    else:
        quantized = quant_ops.LastValueQuantize(source_tensor, **shared_kwargs)

    if quant_delay and quant_delay > 0:
        # The step is read from the 'global_step:0' tensor rather than from
        # common.CreateOrGetQuantizationStep().
        activate_quant = math_ops.greater_equal(
            tf.get_default_graph().get_tensor_by_name('global_step:0'),
            quant_delay,
            name=name_prefix + '/activate_quant')
        undelayed = quantized
        quantized = control_flow_ops.cond(
            activate_quant,
            lambda: undelayed,
            lambda: source_tensor,
            name=name_prefix + '/delayed_quant')

    if not consumers:
        return

    rerouted_count = graph_editor.reroute_ts(
        [quantized], [source_tensor], can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer, so the count only has to be at least the number of consumers.
    if rerouted_count < len(consumers):
        raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
            [consumer.name for consumer in consumers]))
def _GetContextFromOp(op):
    """Returns the part of `op.name` before its last '/'-separated component.

    An empty string is returned when the name has no '/' followed by at
    least one non-'/' character.
    """
    found = re.search(r'^(.*)/([^/]+)', op.name)
    return found.group(1) if found else ''
def _AddContextToName(context, name):
    """Prefixes `name` with `context` and a '/', unless `context` is empty."""
    return context + '/' + name if context else name