From abc660f895f80e3e4386ae5f086e26c7a7169c07 Mon Sep 17 00:00:00 2001 From: Shan Li Date: Thu, 6 Feb 2025 18:21:53 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 724150081 Change-Id: I10ab2aa93c337cc64b998eb60ccdee6f521b3134 --- qkeras/qnormalization.py | 2 +- qkeras/quantizers.py | 254 ++++++++++++++++++++++---------------- qkeras/utils.py | 26 ++-- tests/qactivation_test.py | 198 +++++++++++++++++++---------- 4 files changed, 293 insertions(+), 187 deletions(-) diff --git a/qkeras/qnormalization.py b/qkeras/qnormalization.py index 90245e15..fbc0e51f 100644 --- a/qkeras/qnormalization.py +++ b/qkeras/qnormalization.py @@ -350,7 +350,7 @@ def get_config(self): 'beta_range': self.beta_range, 'gamma_range': self.gamma_range, } - base_config = super(QBatchNormalization, self).get_config() + base_config = super().get_config() return dict(list(base_config.items()) + list(config.items())) def compute_output_shape(self, input_shape): diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py index 54dd2672..c138a871 100644 --- a/qkeras/quantizers.py +++ b/qkeras/quantizers.py @@ -18,7 +18,7 @@ from __future__ import print_function import re -from typing import Any, List, Tuple +from typing import Any, List, Tuple, cast import numpy as np import six @@ -139,12 +139,14 @@ def _get_scaling_axis(scale_axis: Any, len_axis: int) -> List[int]: """ if scale_axis is not None: + # if scale_axis is set, scale over all axis except the scale_axis. if isinstance(scale_axis, list): axis = [i for i in range(len_axis) if i not in scale_axis] else: axis = tf.range(scale_axis) axis = tf.concat([axis, tf.range(scale_axis + 1, len_axis)], axis=0) else: + # if scale_axis is not set, scale over all axis except the channel axis. if K.image_data_format() == "channels_last": axis = tf.range(tf.math.maximum(len_axis - 1, 0)) else: @@ -435,8 +437,8 @@ def _clip_po2_scale(scale: tf.Tensor, min_po2_exponent: Any, def _get_least_squares_scale( - alpha: Any, x: tf.Tensor, q: tf.Tensor, scale_axis: Any = None, - per_channel_scale: bool = True, elements_per_scale: Any = None, + alpha: Any, x: tf.Tensor, q: tf.Tensor, scale_axis: Any = None, + per_channel_scale: bool = True, elements_per_scale: Any = None, min_po2_exponent: Any = None, max_po2_exponent: Any = None): """Gets scaling factor for scaling the tensor per channel. @@ -696,15 +698,15 @@ class quantized_linear(base_quantizer.BaseQuantizer): - Whether we want to have a symmetric range or not - Whether we want to keep negative numbers or not - The quantization scale is defined by either the quantizer parameters or the - data passed to the __call__ method. See documentation for the `alpha` - parameter to find out more. + The quantization scale is defined by either the quantizer parameters or the + data passed to the __call__ method. See documentation for the `alpha` + parameter to find out more. - For backprop purposes, the quantizer uses the straight-through estimator - for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the - gradient of the __call__ method is 1 on the interval - [quantization_scale * clip_min, quantization_scale * clip_max] and 0 - elsewhere. + For backprop purposes, the quantizer uses the straight-through estimator + for the rounding step (https://arxiv.org/pdf/1903.05662.pdf). Thus the + gradient of the __call__ method is 1 on the interval + [quantization_scale * clip_min, quantization_scale * clip_max] and 0 + elsewhere. 
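  A rough sketch of this gradient behaviour (illustrative only; it assumes
  eager TensorFlow and the default quantizer settings):

  >>> import tensorflow as tf
  >>> q = quantized_linear(8, 0)       # representable range is roughly [-1, 1)
  >>> x = tf.Variable([0.2, 10.0])     # 0.2 falls inside, 10.0 outside
  >>> with tf.GradientTape() as tape:
  ...   y = q(x)
  >>> tape.gradient(y, x)              # expected to be approximately [1., 0.]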
The quantizer also supports a number of other optional features: - Stochastic rounding (see the `stochastic_rounding` parameter) @@ -712,18 +714,18 @@ class quantized_linear(base_quantizer.BaseQuantizer): Notes on the various "scales" in quantized_linear: - - The quantization scale is the scale used in the core computation (see - above). You can access it via the `quantization_scale` attribute. - - The data type scale is the scale is determined by the type of data - stored on hardware on a small device running a true quantized model. - It is the quantization scale needed to represent `bits` bits, `integer` - of which are integer bits, and one bit is reserved for the sign if - `keep_negative` is True. It can be calculated as - 2 ** (integer - bits + keep_negative). You can access it via the - `data_type_scale` attribute. - - The `scale` attribute stores the quotient of the quantization scale and - the data type scale. This is also the scale that can be directly - specified by the user, via the `alpha` parameter. + - The quantization scale is the scale used in the core computation (see + above). You can access it via the `quantization_scale` attribute. + - The data type scale is the scale is determined by the type of data + stored on hardware on a small device running a true quantized model. + It is the quantization scale needed to represent `bits` bits, `integer` + of which are integer bits, and one bit is reserved for the sign if + `keep_negative` is True. It can be calculated as + 2 ** (integer - bits + keep_negative). You can access it via the + `data_type_scale` attribute. + - The `scale` attribute stores the quotient of the quantization scale and + the data type scale. This is also the scale that can be directly + specified by the user, via the `alpha` parameter. These three quantities are related by the equation scale = quantization_scale / data_type_scale. @@ -746,19 +748,19 @@ class quantized_linear(base_quantizer.BaseQuantizer): | activation | +------------------------------------------------------------------------+ - # TODO: The only fundamentally necessary scale is the quantization scale. - # We should consider removing the data type scale and scale attributes, - # but know that this will require rewriting much of how qtools and HLS4ML - # use these scale attributes. + # TODO: The only fundamentally necessary scale is the quantization scale. + # We should consider removing the data type scale and scale attributes, + # but know that this will require rewriting much of how qtools and HLS4ML + # use these scale attributes. - Note on binary quantization (bits=1): - The core computation is modified here when `keep_negative` is True to - perform a scaled sign function. This is needed because the core - computation as defined above requires that 0 be mapped to 0, which does - not allow us to keep both positive and negative outputs for binary - quantization. Special shifting operations are used to achieve this. + Note on binary quantization (bits=1): + The core computation is modified here when `keep_negative` is True to + perform a scaled sign function. This is needed because the core + computation as defined above requires that 0 be mapped to 0, which does + not allow us to keep both positive and negative outputs for binary + quantization. Special shifting operations are used to achieve this. 
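  A quick numerical check of the three scales described above (a sketch; the
  values assume bits=8, integer=3, keep_negative=True and the default
  alpha=None):

  >>> q = quantized_linear(8, 3)
  >>> # data_type_scale = 2 ** (integer - bits + keep_negative) = 2 ** -4
  >>> assert float(q.data_type_scale) == 2.0 ** -4
  >>> # With alpha=None the quantization scale is just the data type scale,
  >>> # so their quotient, the `scale` attribute, is 1.
  >>> assert float(q.quantization_scale) == float(q.data_type_scale)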
- Example usage: + Example usage: # 8-bit quantization with 3 integer bits >>> q = quantized_linear(8, 3) @@ -779,41 +781,45 @@ class quantized_linear(base_quantizer.BaseQuantizer): >>> q_fixed(x) array([0., 0., 0., 2., 2.], dtype=float32) - Args: - bits (int): Number of bits to represent the number. Defaults to 8. - integer (int): Number of bits to the left of the decimal point, used for - data_type_scale. Defaults to 0. - symmetric (bool): If true, we will have the same number of values for - positive and negative numbers. Defaults to True. - alpha (str, Tensor, None): Instructions for determining the quantization - scale. Defaults to None. - If None: the quantization scale is the data - type scale, determined by `integer`, `bits`, and `keep_negative`. - If - "auto", the quantization scale is calculated as the minimum floating point - scale per-channel that does not clip the max of x. - If "auto_po2", the - quantization scale is chosen as the power of two per-channel that - minimizes squared error between the quantized x and the original x. - If - Tensor: The quantization scale is the Tensor passed in multiplied by the - data type scale. - keep_negative (bool): If false, we clip negative numbers. Defaults to True. - use_stochastic_rounding (bool): If true, we perform stochastic rounding - (https://arxiv.org/pdf/1502.02551.pdf). - scale_axis (int, None): Which axis to calculate scale from. If None, we - perform per-channel scaling based off of the image data format. Note that - each entry of a rank-1 tensor is considered its own channel by default. - See `_get_scaling_axis` for more details. Defaults to None. - qnoise_factor (float): A scalar from 0 to 1 that represents the level of - quantization noise to add. This controls the amount of the quantization - noise to add to the outputs by changing the weighted sum of (1 - - qnoise_factor) * unquantized_x + qnoise_factor * quantized_x. Defaults to - 1.0, which means that the result is fully quantized. - use_variables (bool): If true, we use tf.Variables to store certain - parameters. See the base_quantizer.BaseQuantizer implementation for more - details. Defaults to False. If set to True, be sure to use the special - attribute update methods detailed in the base_quantizer.BaseQuantizer. - var_name (str or None): A variable name shared between the tf.Variables - created in on initialization, if use_variables is true. If None, the - variable names are generated automatically based on the parameter names - along with a uid. Defaults to None. + Args: + bits (int): Number of bits to represent the number. Defaults to 8. + integer (int): Number of bits to the left of the decimal point, used for + data_type_scale. Defaults to 0. + symmetric (bool): If true, we will have the same number of values + for positive and negative numbers. Defaults to True. + alpha (str, Tensor, None): Instructions for determining the quantization + scale. Defaults to None. + - If None: the quantization scale is the data type scale, determined + by `integer`, `bits`, and `keep_negative`. + - If "auto", the quantization scale is calculated as the minimum + floating point scale per-channel that does not clip the max of x. + - If "auto_po2", the quantization scale is chosen as the + power of two per-channel that minimizes squared error between the + quantized x and the original x. + - If Tensor: The quantization scale is the Tensor passed in + multiplied by the data type scale. + keep_negative (bool): If false, we clip negative numbers. Defaults to + True. 
+ use_stochastic_rounding (bool): If true, we perform stochastic rounding + (https://arxiv.org/pdf/1502.02551.pdf). + scale_axis (int, None): Which axis to calculate scale from. If None, we + perform per-channel scaling based off of the image data format. Note + that each entry of a rank-1 tensor is considered its own channel by + default. See `_get_scaling_axis` for more details. Defaults to None. + qnoise_factor (float): A scalar from 0 to 1 that represents the level of + quantization noise to add. This controls the amount of the + quantization noise to add to the outputs by changing the weighted + sum of (1 - qnoise_factor) * unquantized_x + qnoise_factor * + quantized_x. Defaults to 1.0, which means that the result is fully + quantized. + use_variables (bool): If true, we use tf.Variables to store certain + parameters. See the BaseQuantizer implementation for more details. + Defaults to False. If set to True, be sure to use the special attribute + update methods detailed in the BaseQuantizer. + var_name (str or None): A variable name shared between the tf.Variables + created in on initialization, if use_variables is true. If None, the + variable names are generated automatically based on the parameter names + along with a uid. Defaults to None. Returns: function: Function that computes linear quantization. @@ -991,7 +997,7 @@ def __call__(self, x): def _scale_clip_and_round(self, x, quantization_scale): """Scale, clip, and round x to an integer value in a limited range - Note that the internal shift is needed for 1-bit quantization to ensure + Note that the internal shift is needed for 1-bit quantization to ensure that a sign function is used. Otherise, the binary quantizer would have three output values""" @@ -1027,10 +1033,10 @@ def _get_auto_quantization_scale(self, x): self.quantization_scale = tf.stop_gradient(quantization_scale) # very important that return value is a tf.Variable with shape None - return self.quantization_scale + return self.quantization_scale def _get_quantization_scale_from_max_data(self, x): - """Get the minimum floating point scale that does not clip the max + """Get the minimum floating point scale that does not clip the max of x""" axis = _get_scaling_axis(self.scale_axis, tf.rank(x)) @@ -1053,9 +1059,9 @@ def _po2_autoscale(self, x, quantization_scale): """Get an approximation of the "best" po2 scale using least squares""" # set alpha scale to a near power of two - quantization_scale = K.pow(2.0, - tf.math.round(K.log(quantization_scale + K.epsilon()) / - K.log(2.0))) + quantization_scale = K.pow( + 2.0, + tf.math.round(K.log(quantization_scale + K.epsilon()) / K.log(2.0))) def loop_body(_, quantization_scale): """Loop body for least squares autoscaling""" @@ -1070,7 +1076,7 @@ def loop_body(_, quantization_scale): return quantization_scale, new_quantization_scale def loop_cond(last_quantization_scale, quantization_scale): - """Loop condition for least squares autoscaling- stop when the + """Loop condition for least squares autoscaling- stop when the scale converges""" tensors_not_equal = tf.math.reduce_any( @@ -1129,8 +1135,8 @@ def __str__(self): # Main parameters always printed in string flags = [ - str(int(self.bits)), - str(int(self.integer)), + str(int(self.bits)), + str(int(self.integer)), str(int(self.symmetric))] # Optional parameters only printed if not default if not self.keep_negative: @@ -1234,6 +1240,8 @@ class quantized_bits(base_quantizer.BaseQuantizer): # pylint: disable=invalid-n allowed power of two exponent. 
max_po2_exponent: if set while using "auto_po2", it represents the maximum allowed power of two exponent. + post_training_scale: if set, it represents the scale value to be used for + quantization. Returns: Function that computes fixed-point quantization with bits. @@ -1253,7 +1261,8 @@ def __init__(self, use_variables=False, elements_per_scale=None, min_po2_exponent=None, - max_po2_exponent=None): + max_po2_exponent=None, + post_training_scale=None): super().__init__() self.bits = bits @@ -1262,10 +1271,22 @@ def __init__(self, self.keep_negative = keep_negative self.alpha = alpha self.use_stochastic_rounding = use_stochastic_rounding + self.post_training_scale = post_training_scale # "auto*" |-> symmetric if isinstance(self.alpha, six.string_types): + self.freeze_scale = False self.symmetric = True - self.scale = None + if post_training_scale is not None: + self.scale = np.array(post_training_scale) + self.freeze_scale = True + else: + if post_training_scale is not None: + raise ValueError(f"alpha={alpha} doesn't support post_training_scale: " + f"{post_training_scale}") + self.scale = None + # If alpha is not "auto*", then scale is fixed and not trainable. + self.freeze_scale = True + self.scale_axis = scale_axis self.qnoise_factor = qnoise_factor self.use_ste = use_ste @@ -1346,29 +1367,43 @@ def __call__(self, x): # using 2's complement. levels = (2**(self.bits-1)-1) * 2 if self.symmetric else (2**self.bits)-1 - scale = (K.max(abs(x), axis=axis, keepdims=True) * 2) / levels - - # If alpha is "auto_po2", then get the "best" po2 scale - if "po2" in self.alpha: - scale = K.pow(2.0, - tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) - for idx in range(5): - v = tf.floor(tf.abs(x) / scale + 0.5) - mask = v < levels / 2 - z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2) - scale = _get_least_squares_scale(alpha="auto_po2", x=x, q=z, - scale_axis=self.scale_axis, - elements_per_scale=self.elements_per_scale, - min_po2_exponent=self.min_po2_exponent, - max_po2_exponent=self.max_po2_exponent) - - # If alpha is "auto", then get the "best" floating point scale - elif self.alpha == "auto": - v = tf.floor(tf.abs(x) / scale + 0.5) - mask = v < levels / 2 - z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2) + if self.freeze_scale: + # Scale is fixed value. In this case, scale is extracted from the + # post-training quantizater scale. In order to retrain models with + # this scale value, we need to divide it by m to make it in the same + # value scale as x. + scale = self.scale / m else: - raise ValueError(f"Invalid alpha '{self.alpha}'") + # Calculate the scale. + scale = (K.max(abs(x), axis=axis, keepdims=True) * 2) / levels + + # If alpha is "auto_po2", then get the "best" po2 scale + if "po2" in self.alpha: + scale = K.pow(2.0, + tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + for idx in range(5): + v = tf.floor(tf.abs(x) / scale + 0.5) + mask = v < levels / 2 + z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2) + scale = _get_least_squares_scale( + alpha="auto_po2", x=x, q=z, + scale_axis=self.scale_axis, + elements_per_scale=self.elements_per_scale, + min_po2_exponent=self.min_po2_exponent, + max_po2_exponent=self.max_po2_exponent) + + elif self.alpha != "auto": + # If alpha is "auto", then directly uuse the "best" + # floating point scale. + raise ValueError(f"Invalid alpha '{self.alpha}'") + + # Even for trainable scale, we still need to quantize x with the best + # scale. 
This extra step is needed to ensure that with the same input + # and scale, the quantized output is identical between training and + # inference. + v = tf.floor(tf.abs(x) / scale + 0.5) + mask = v < levels / 2 + z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * levels / 2) # z is an integer number, so we must make the scale * m and z / m scale = scale * m @@ -1382,7 +1417,8 @@ def __call__(self, x): # return x + tf.stop_gradient(-x + scale * z) x = m_i * x xq = m_i * z / m - self.scale = scale + if not self.freeze_scale: + self.scale = scale xq = scale * xq if self.use_ste: @@ -1418,6 +1454,7 @@ def __call__(self, x): def _set_trainable_parameter(self): if self.alpha is None: self.alpha = "auto_po2" + self.freeze_scale = False self.symmetric = True def max(self): @@ -1461,6 +1498,10 @@ def range(self): @classmethod def from_config(cls, config): + # Convert JSON-serializable lists back to NumPy arrays. + if config.get("post_training_scale") is not None: + config["post_training_scale"] = np.array(config["post_training_scale"]) + return cls(**config) def get_config(self): @@ -1480,7 +1521,12 @@ def get_config(self): self.use_stochastic_rounding, "qnoise_factor": self.qnoise_factor.numpy() if isinstance( - self.qnoise_factor, tf.Variable) else self.qnoise_factor + self.qnoise_factor, tf.Variable) else self.qnoise_factor, + "post_training_scale": + # Since NumPy arrays are not directly JSON-serializable, + # we convert them to lists. + (self.post_training_scale.tolist() if self.post_training_scale is + not None else None) } return config diff --git a/qkeras/utils.py b/qkeras/utils.py index d7262e85..4c378ae7 100644 --- a/qkeras/utils.py +++ b/qkeras/utils.py @@ -281,7 +281,10 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}): has_scale = False enable_bn_fusing = False - if (isinstance(layer, QBatchNormalization) and + # isinstance() might fail due to inconsistent module import path. + # Use __class__.__name__ instead. 
+ layer_class = layer.__class__.__name__ + if (layer_class == "QBatchNormalization" and layer.name in bn_layers_to_skip): # Mark current bn layer to be fused with the previous layer enable_bn_fusing = True @@ -334,10 +337,14 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}): unsigned_bits = quantizer.bits - quantizer.keep_negative m = K.cast_to_floatx(pow(2, unsigned_bits)) m_i = K.cast_to_floatx(K.pow(2, quantizer.integer)) - assert hasattr(quantizer.scale, "numpy"), ( - "The auto_po2 quantizer has to be called first in order to know " - "the values of scale.") - scale = K.cast_to_floatx(quantizer.scale.numpy()) + + assert hasattr(quantizer.scale, "numpy") or isinstance( + quantizer.scale, np.ndarray), ( + "The auto_po2 quantizer has to be called first in order " + "to know the values of scale.") + scale = quantizer.scale if isinstance( + quantizer.scale, np.ndarray) else quantizer.scale.numpy() + scale = K.cast_to_floatx(scale) # Make sure scale is power of 2 values log2val = np.log2(scale) diff = np.round(log2val) - log2val @@ -1078,15 +1085,6 @@ def clone_model(model, custom_objects=None): return qmodel - config = { - "class_name": model.__class__.__name__, - "config": model.get_config(), - } - clone = tf.keras.models.model_from_config( - config, custom_objects=custom_objects) - clone.set_weights(model.get_weights()) - return clone - def quantized_model_from_json(json_string, custom_objects=None): if not custom_objects: diff --git a/tests/qactivation_test.py b/tests/qactivation_test.py index daa4a06d..b0f09fa7 100644 --- a/tests/qactivation_test.py +++ b/tests/qactivation_test.py @@ -18,10 +18,12 @@ from __future__ import division from __future__ import print_function import numpy as np -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_array_equal import pytest +from tensorflow import keras from tensorflow.keras import backend as K +import tempfile from qkeras import set_internal_sigmoid from qkeras import binary @@ -85,7 +87,8 @@ ), (4, 4, 0, 0, "floor", np.array([[-7, -0.12, -0.03, 0, 0.01, 5]], dtype=K.floatx()), - np.array([[-4, -0.0625, -0.0625, 0.0625, 0.0625, 4]], dtype=K.floatx()), + np.array([[-4, -0.0625, -0.0625, 0.0625, 0.0625, 4]], + dtype=K.floatx()), ), (4, None, 0, 1, "floor", np.array( @@ -112,13 +115,14 @@ dtype=K.floatx()), ), ]) -def disable_test_quantized_po2(bits, - max_value, - use_stochastic_rounding, - quadratic_approximation, - log2_rounding, - test_values, - expected_values): +def disable_test_quantized_po2( + bits, + max_value, + use_stochastic_rounding, + quadratic_approximation, + log2_rounding, + test_values, + expected_values): """Test quantized_po2 function.""" x = K.placeholder(ndim=2) f = K.function([x], [quantized_po2( @@ -246,7 +250,7 @@ def ref_hard_sigmoid(y): "hard", False, np.array( - [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], + [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), np.array([[0.015625, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]], dtype=K.floatx()), @@ -256,9 +260,10 @@ def ref_hard_sigmoid(y): "smooth", False, np.array( - [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], + [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), - np.array([[0.3125, 0.359375, 0.40625, 0.453125, 0.5, 0.546875, 0.59375, 0.640625]], + np.array([[0.3125, 0.359375, 0.40625, 0.453125, + 0.5, 0.546875, 0.59375, 0.640625]], dtype=K.floatx()), ), ( @@ -266,18 +271,22 @@ def ref_hard_sigmoid(y): "real", True, np.array( - [[-1., -0.75, -0.5, -0.25, 
0., 0.25, 0.5, 0.75]], + [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), - np.array([[0.265625, 0.328125, 0.375, 0.4375, 0.5, 0.5625, 0.625, 0.671875]], + np.array([[0.265625, 0.328125, 0.375, 0.4375, 0.5, + 0.5625, 0.625, 0.671875]], dtype=K.floatx()), ), ]) -def test_quantized_sigmoid(bits, sigmoid_type, use_real_sigmoid, test_values, expected_values): +def test_quantized_sigmoid(bits, sigmoid_type, use_real_sigmoid, + test_values, expected_values): """Test quantized_sigmoid function with three different sigmoid variants.""" set_internal_sigmoid(sigmoid_type) x = K.placeholder(ndim=2) - f = K.function([x], [quantized_sigmoid(bits, symmetric=True, use_real_sigmoid=use_real_sigmoid)(x)]) + f = K.function([x], + [quantized_sigmoid(bits, symmetric=True, + use_real_sigmoid=use_real_sigmoid)(x)]) set_internal_sigmoid(_default_sigmoid_type) result = f([test_values])[0] @@ -291,7 +300,7 @@ def test_quantized_sigmoid(bits, sigmoid_type, use_real_sigmoid, test_values, ex "hard", False, np.array( - [-15, 15], + [-15, 15], dtype=K.floatx()), np.array([0.0625, 0.9375], dtype=K.floatx()), @@ -301,7 +310,7 @@ def test_quantized_sigmoid(bits, sigmoid_type, use_real_sigmoid, test_values, ex "smooth", False, np.array( - [-15, 15], + [-15, 15], dtype=K.floatx()), np.array([0.0625, 0.9375], dtype=K.floatx()), @@ -311,25 +320,30 @@ def test_quantized_sigmoid(bits, sigmoid_type, use_real_sigmoid, test_values, ex "real", True, np.array( - [-15, 15], + [-15, 15], dtype=K.floatx()), np.array([0.0625, 0.9375], dtype=K.floatx()), ), ]) -def test_quantized_sigmoid_limits(bits, sigmoid_type, use_real_sigmoid, test_values, expected_values): +def test_quantized_sigmoid_limits( + bits, sigmoid_type, use_real_sigmoid, test_values, expected_values): """Test the min and max values of quantized_sigmoid function with three different sigmoid variants.""" set_internal_sigmoid(sigmoid_type) x = K.placeholder(ndim=2) - f = K.function([x], [quantized_sigmoid(bits, symmetric=True, use_real_sigmoid=use_real_sigmoid)(x)]) + f = K.function([x], + [quantized_sigmoid(bits, symmetric=True, + use_real_sigmoid=use_real_sigmoid)(x)]) set_internal_sigmoid(_default_sigmoid_type) result = f([test_values])[0] min_max = np.array( - [quantized_sigmoid(bits, symmetric=True, use_real_sigmoid=use_real_sigmoid).min(), - quantized_sigmoid(bits, symmetric=True, use_real_sigmoid=use_real_sigmoid).max()]) + [quantized_sigmoid(bits, symmetric=True, + use_real_sigmoid=use_real_sigmoid).min(), + quantized_sigmoid(bits, symmetric=True, + use_real_sigmoid=use_real_sigmoid).max()]) assert_allclose(result, expected_values, rtol=1e-05) assert_allclose(result, min_max, rtol=1e-05) @@ -341,18 +355,18 @@ def test_quantized_sigmoid_limits(bits, sigmoid_type, use_real_sigmoid, test_val 4, False, np.array( - [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], + [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), - np.array([[-0.875, -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], + np.array([[-0.875, -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), ), ( 4, True, np.array( - [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], + [[-1., -0.75, -0.5, -0.25, 0., 0.25, 0.5, 0.75]], dtype=K.floatx()), - np.array([[-0.75, -0.625, -0.5, -0.25, 0., 0.25, 0.5, 0.625]], + np.array([[-0.75, -0.625, -0.5, -0.25, 0., 0.25, 0.5, 0.625]], dtype=K.floatx()), ) ]) @@ -362,7 +376,8 @@ def test_quantized_tanh(bits, use_real_tanh, test_values, expected_values): set_internal_sigmoid('hard') x = K.placeholder(ndim=2) - f = K.function([x], 
[quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh)(x)]) + f = K.function([x], [quantized_tanh( + bits, symmetric=True, use_real_tanh=use_real_tanh)(x)]) set_internal_sigmoid(_default_sigmoid_type) result = f([test_values])[0] @@ -402,18 +417,20 @@ def test_quantized_tanh(bits, use_real_tanh, test_values, expected_values): dtype=K.floatx()), ), ]) -def test_quantized_tanh_limits(bits, sigmoid_type, use_real_tanh, test_values, expected_values): +def test_quantized_tanh_limits(bits, sigmoid_type, use_real_tanh, test_values, + expected_values): """Test the min and max values of quantized_tanh function with three different sigmoid variants.""" set_internal_sigmoid(sigmoid_type) x = K.placeholder(ndim=2) - f = K.function([x], [quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh)(x)]) + f = K.function([x], [quantized_tanh( + bits, symmetric=True, use_real_tanh=use_real_tanh)(x)]) set_internal_sigmoid(_default_sigmoid_type) result = f([test_values])[0] min_max = np.array( - [quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh).min(), - quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh).max()]) + [quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh).min(), + quantized_tanh(bits, symmetric=True, use_real_tanh=use_real_tanh).max()]) assert_allclose(result, expected_values, rtol=1e-05) assert_allclose(result, min_max, rtol=1e-05) @@ -439,8 +456,7 @@ def test_quantized_tanh_limits(bits, sigmoid_type, use_real_tanh, test_values, e 0.544922, 1.046875, 0.586899, 3.367188, 3.804688, 0.312500, 0.062500, 0.562500, 0.375000, 3.367188, 1.046875, 2.796875, 0.054688, 1.562500, 2.562500 - ]], - dtype=K.floatx()), + ]], dtype=K.floatx()), np.array([[ 0.500000, 0.625000, 0.250000, 1.500000, 0.000000, 3.937500, 3.937500, 3.937500, 0.375000, 0.875000, 0.750000, 3.937500, @@ -448,8 +464,7 @@ def test_quantized_tanh_limits(bits, sigmoid_type, use_real_tanh, test_values, e 0.500000, 1.000000, 0.625000, 3.375000, 3.750000, 0.250000, 0.000000, 0.500000, 0.375000, 3.375000, 1.000000, 2.750000, 0.000000, 1.500000, 2.500000 - ]], - dtype=K.floatx())), + ]], dtype=K.floatx())), ]) def test_quantized_relu(bits, integer, use_sigmoid, test_values, expected_values): """Test quantized_relu function.""" @@ -518,6 +533,53 @@ def test_quantized_bits(bits, integer, symmetric, keep_negative, test_values, assert_allclose(result, expected_values, rtol=rtol) +@pytest.mark.parametrize( + "bits, integer, expected_output, expected_scale", + [(4, 2, + [[0.25, 3.0, 0.09375, 0.25], [0.4375, 0.0, 0.21875, 1.5]], + [[0.125, 1., 0.0625, 0.5]]), + (4, 1, [[0.25, 3., 0.09375, 0.25], [0.4375, 0., 0.21875, 1.5]], + [[0.25, 2., 0.125, 1.]]), + (5, 2, + [[0.21875, 2.75, 0.09375, 0.375], [0.46875, 0.25, 0.234375, 1.375]], + [[0.125, 1, 0.0625, 0.5]]), + ]) +def test_quantized_bits_with_auto_po2_scale( + bits, integer, expected_output, expected_scale): + # Test if quantizer with the fixed scale works properly. + x = np.array([[0.23, 2.76, 0.1, 0.33], [0.53, 0.16, 0.3, 1.43]]) + + q = quantized_bits( + bits=bits, integer=integer, alpha="auto_po2") + q_out = q(x).numpy() + scale = q.scale.numpy() + + np.testing.assert_array_equal(q_out, expected_output) + np.testing.assert_array_equal(scale, expected_scale) + + +def test_quantized_bits_with_post_training_scale(): + # Test if quantizer with the fixed scale works properly. 
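+  # The auto_po2 quantizer below is run once to calibrate a power-of-two
+  # scale; that scale is then frozen into a second quantizer through
+  # `post_training_scale`, and both quantizers are expected to produce
+  # identical quantized outputs for the same input.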
+ np.random.seed(42) + array = np.random.uniform(low=0, high=10, size=(7, 64, 64, 3)) + + auto_po2_quantizer = quantized_bits( + bits=8, integer=3, alpha="auto_po2") + qw = auto_po2_quantizer(array) + auto_po2_scale = auto_po2_quantizer.scale.numpy() + alpha_ndarray_quantizer = quantized_bits( + bits=8, integer=3, alpha="auto_po2", + post_training_scale=auto_po2_scale) + + # Check if the scale is the same as auto_po2 quantizer. + np.testing.assert_array_equal(auto_po2_scale, + alpha_ndarray_quantizer.scale) + + qw_ndarray = alpha_ndarray_quantizer(array) + # Check if the quantized values are the same as auto_po2 quantizer. + np.testing.assert_array_equal(qw.numpy(), qw_ndarray.numpy()) + + @pytest.mark.parametrize('alpha, threshold, test_values, expected_values', [ (1.0, 0.33, np.array([[-3.0, -2.0, -1.0, -0.2, 0.0, 0.3, 1, 4, 10]], dtype=K.floatx()), @@ -652,26 +714,29 @@ def test_stochastic_binary_inference_mode(alpha, test_values, expected_values): @pytest.mark.parametrize( 'bound, alpha, temperature, expected_values, expected_scale', [ - ( - 0.01, - "auto", - 8, - np.array([-0.973, -0.903, -0.759, -0.574, -0.242, 0.161, 0.508, 0.723, - 0.874, 0.975]).astype(np.float32), - np.array([0.008427, 0.007001, 0.0057 , 0.004457, 0.003537, 0.003416, - 0.004507, 0.005536, 0.006853, 0.008282]).astype(np.float32) - ), - ( - 0.01, - "auto_po2", - 8, - np.array([-0.979, -0.877, -0.639, -0.586, -0.23 , 0.154, 0.327, 0.603, - 0.83 , 0.986]).astype(np.float32), - np.array([0.007812, 0.007812, 0.007812, 0.003906, 0.003906, 0.003906, - 0.007812, 0.007812, 0.007812, 0.007812]).astype(np.float32) - ) -]) -def test_stochastic_ternary(bound, alpha, temperature, expected_values, expected_scale): + ( + 0.01, + "auto", + 8, + np.array([-0.973, -0.903, -0.759, -0.574, -0.242, 0.161, 0.508, + 0.723, 0.874, 0.975]).astype(np.float32), + np.array([0.008427, 0.007001, 0.0057, 0.004457, 0.003537, 0.003416, + 0.004507, 0.005536, 0.006853, 0.008282] + ).astype(np.float32) + ), + ( + 0.01, + "auto_po2", + 8, + np.array([-0.979, -0.877, -0.639, -0.586, -0.23, 0.154, + 0.327, 0.603, 0.83, 0.986]).astype(np.float32), + np.array([0.007812, 0.007812, 0.007812, 0.003906, 0.003906, + 0.003906, 0.007812, 0.007812, 0.007812, 0.007812] + ).astype(np.float32) + ) + ]) +def test_stochastic_ternary(bound, alpha, temperature, expected_values, + expected_scale): np.random.seed(42) K.set_learning_phase(1) @@ -704,7 +769,8 @@ def test_stochastic_ternary(bound, alpha, temperature, expected_values, expected dtype=K.floatx()), np.array([[-10.0, -10.0, 0.0, 0, 0.0, 0.0, 0, 0, 10]], dtype=K.floatx())), ]) -def test_stochastic_ternary_inference_mode(alpha, threshold, test_values, expected_values): +def test_stochastic_ternary_inference_mode(alpha, threshold, test_values, + expected_values): K.set_learning_phase(0) x = K.placeholder(ndim=2) q = stochastic_ternary(alpha, threshold) @@ -719,34 +785,30 @@ def test_stochastic_ternary_inference_mode(alpha, threshold, test_values, expect # bits. The quantization is in asymmetric mode. 
('bits, integer, symmetric, relu_shift, relu_upper_bound,' 'test_values, expected_values'), [ - ( - 6, 2, 0, 3, 6, + (6, 2, 0, 3, 6, np.array([[-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1, 4, 10]], dtype=K.floatx()), - np.array([[0., -0.375, -0.375, -0.25, 0., 0.25, 0.625, + np.array([[0., -0.375, -0.375, -0.25, 0., 0.25, 0.625, 3.875, 3.875]], dtype=K.floatx()), ), - ( - 6, 4, 1, 3, 6, + (6, 4, 1, 3, 6, np.array([[-10.0, -2.0, -2.3, -0.25, 0.0, 0.5, 1, 4, 10]], dtype=K.floatx()), np.array([[0., -0.5, -0.5, 0., 0., 0.5, 0.5, 4., 10.]], dtype=K.floatx()), ), - ( - 2, 0, 0, 3, 6, + (2, 0, 0, 3, 6, np.array([[-10.0, -2.0, -2.3, -0.25, 0.0, 0.5, 1, 4, 10]], dtype=K.floatx()), np.array([[0., -0.5, -0.5, 0., 0., 0.5, 0.5, 0.5, 0.5]], dtype=K.floatx()), - ), - ]) + ),]) def test_quantized_hswish(bits, integer, symmetric, relu_shift, relu_upper_bound, test_values, expected_values): x = K.placeholder(ndim=2) f = K.function( - [x], [quantized_hswish(bits, integer, symmetric,relu_shift=relu_shift, - relu_upper_bound=relu_upper_bound)(x)]) + [x], [quantized_hswish(bits, integer, symmetric, relu_shift=relu_shift, + relu_upper_bound=relu_upper_bound)(x)]) result = f([test_values])[0] assert_allclose(result, expected_values, rtol=1e-05)
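As a closing illustration, here is a minimal end-to-end sketch of the
post_training_scale workflow this patch enables. It mirrors
test_quantized_bits_with_post_training_scale and the get_config/from_config
changes above; the top-level import path and an eager-execution environment
are assumptions.

import numpy as np
from qkeras import quantized_bits

# 1. Calibration pass: let the auto_po2 quantizer derive a power-of-two scale.
calib = np.random.uniform(low=0, high=10, size=(7, 64, 64, 3))
q_auto = quantized_bits(bits=8, integer=3, alpha="auto_po2")
qw_auto = q_auto(calib)
frozen_scale = q_auto.scale.numpy()

# 2. Freeze the calibrated scale into a new quantizer for retraining/inference.
q_frozen = quantized_bits(bits=8, integer=3, alpha="auto_po2",
                          post_training_scale=frozen_scale)
qw_frozen = q_frozen(calib)

# Same input and scale should yield identical quantized outputs.
np.testing.assert_array_equal(qw_auto.numpy(), qw_frozen.numpy())

# 3. The frozen scale survives serialization: get_config() stores it as nested
#    lists and from_config() converts it back to a NumPy array.
q_restored = quantized_bits.from_config(q_frozen.get_config())
assert q_restored.post_training_scale is not None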