po2 max_value and min_value fixed.
Enhance initialization in qlayers.py.

PiperOrigin-RevId: 305970234
Change-Id: I5ad84a632ce0be6108ba13b8c554be4cbe34b038
zhuangh authored and copybara-github committed Apr 10, 2020
1 parent 244ebe8 commit 61479f9
Showing 8 changed files with 365 additions and 93 deletions.
qkeras/__init__.py: 2 changes (1 addition, 1 deletion)
@@ -27,4 +27,4 @@
 from .qpooling import * # pylint: disable=wildcard-import
 from .safe_eval import * # pylint: disable=wildcard-import

-__version__ = "0.7.0"
+__version__ = "0.7.4"
qkeras/estimate.py: 2 changes (1 addition, 1 deletion)
@@ -456,7 +456,7 @@ def extract_model_operations(model):
 
     number_of_operations = (size_i * size_o)
 
-    number_of_weights = size_i * size_o
+    number_of_weights = size_i * size_o
     number_of_bias = 0
     if len(layer.get_weights()) > 1:
       number_of_bias = layer.get_weights()[1].shape[0]
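For reference, size_i and size_o in this hunk are the dense layer's input and output sizes (presumably derived from the kernel shape earlier in the function, outside the shown context). The counting rule the hunk touches is the standard one: a dense layer performs size_i * size_o multiply-accumulates, stores size_i * size_o weights, and adds one bias per output when biases are present. A minimal standalone sketch of that rule (dense_counts is a hypothetical helper, not part of this commit):

def dense_counts(layer):
  """Parameter/operation counts for a built Keras Dense layer (sketch)."""
  weights = layer.get_weights()
  size_i, size_o = weights[0].shape        # kernel shape is (inputs, outputs)
  number_of_operations = size_i * size_o   # one multiply-accumulate per weight
  number_of_weights = size_i * size_o
  number_of_bias = weights[1].shape[0] if len(weights) > 1 else 0
  return number_of_operations, number_of_weights, number_of_bias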
qkeras/qlayers.py: 70 changes (64 additions, 6 deletions)
@@ -34,6 +34,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import numpy as np
 import warnings
 import six
 import tensorflow.compat.v2 as tf
@@ -42,9 +43,9 @@
 from tensorflow.keras import initializers
 from tensorflow.keras import regularizers
 from tensorflow.keras.constraints import Constraint
+from tensorflow.keras.initializers import Initializer
 from tensorflow.keras.layers import Dense
 from tensorflow.keras.layers import Layer
-from .quantizers import get_quantized_initializer
 from .quantizers import get_quantizer
 from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer

@@ -64,21 +65,79 @@ def get_auto_range_constraint_initializer(quantizer, constraint, initializer):
     initializer is initializer constrained by value range of quantizer.
   """
   if quantizer is not None:
-    max_value = quantizer.max() if hasattr(quantizer, "max") else 1.0
-    min_value = quantizer.min() if hasattr(quantizer, "min") else -1.0
+    # let's use now symmetric clipping function
+    max_value = max(1, quantizer.max()) if hasattr(quantizer, "max") else 1.0

     if constraint:
       constraint = constraints.get(constraint)
-    constraint = Clip(min_value, max_value, constraint, quantizer)
-    initializer = get_quantized_initializer(initializer,
-                                            max(abs(min_value), abs(max_value)))
+
+    constraint = Clip(-max_value, max_value, constraint, quantizer)
+    initializer = initializers.get(initializer)
+    if initializer and initializer.__class__.__name__ not in ["Ones", "Zeros"]:
+      # we want to get the max value of the quantizer that depends
+      # on the distribution and scale
+      if not (hasattr(quantizer, "alpha") and
+              isinstance(quantizer.alpha, six.string_types)):
+        initializer = QInitializer(
+            initializer, use_scale=True, quantizer=quantizer)
   return constraint, initializer


+class QInitializer(Initializer):
+  """Wraps around Keras initializer to provide a fanin scaling factor."""
+
+  def __init__(self, initializer, use_scale, quantizer):
+    self.initializer = initializer
+    self.use_scale = use_scale
+    self.quantizer = quantizer
+
+    try:
+      self.is_po2 = "po2" in quantizer.__class__.__name__
+    except:
+      self.is_po2 = False
+
+  def __call__(self, shape, dtype=None):
+    x = self.initializer(shape, dtype)
+
+    max_x = np.max(abs(x))
+    std_x = np.std(x)
+    delta = self.quantizer.max() * 2**-self.quantizer.bits
+
+    # delta is the minimum resolution of the number system.
+    # we want to make sure we have enough values.
+    if delta > std_x and hasattr(self.initializer, "scale"):
+      q = self.quantizer(x)
+      max_q = np.max(abs(q))
+      scale = 1.0
+      if max_q == 0.0:
+        xx = np.mean(x * x)
+        scale = self.quantizer.max() / np.sqrt(xx)
+      else:
+        qx = np.sum(q * x)
+        qq = np.sum(q * q)
+
+        scale = qq / qx
+
+      self.initializer.scale *= max(scale, 1)
+      x = self.initializer(shape, dtype)
+
+    return np.clip(x, -self.quantizer.max(), self.quantizer.max())
+
+  def get_config(self):
+    return {
+        "initializer": self.initializer,
+        "use_scale": self.use_scale,
+        "quantizer": self.quantizer,
+    }


 #
 # Because it may be hard to get serialization from activation functions,
 # we may be replacing their instantiation by QActivation in the future.
 #


 class QActivation(Layer, PrunableLayer):
   """Implements quantized activation layers."""

@@ -309,4 +368,3 @@ def get_quantizers(self):

   def get_prunable_weights(self):
     return [self.kernel]
-
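Tying back to the commit title: the old code clipped weights to [quantizer.min(), quantizer.max()], a range that can be asymmetric, while the new code clips to the symmetric range [-max_value, max_value] with max_value floored at 1, so a quantizer with a small dynamic range never squeezes the constraint below the unit interval. A hedged sketch of that behavior (symmetric_clip is an illustrative stand-in, not the QKeras Clip class):

import numpy as np

def symmetric_clip(w, quantizer_max):
  # Floor the range at 1, then clip symmetrically around zero,
  # mirroring Clip(-max_value, max_value, ...) in this commit.
  max_value = max(1.0, quantizer_max)
  return np.clip(w, -max_value, max_value)

print(symmetric_clip(np.array([-12.0, 0.3, 9.0]), quantizer_max=8.0))
# roughly [-8.   0.3  8. ]
print(symmetric_clip(np.array([-12.0, 0.3, 9.0]), quantizer_max=0.5))
# roughly [-1.   0.3  1. ]  (range floored at 1)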
