
Commit

No public description
PiperOrigin-RevId: 725819350
Change-Id: Ie196effeb9912c7cdd38e53a2b69a1716147ba4c
lishanok authored and copybara-github committed Feb 12, 2025
1 parent abc660f commit a1fec24
Showing 2 changed files with 296 additions and 4 deletions.
186 changes: 182 additions & 4 deletions qkeras/utils.py
@@ -309,6 +309,7 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
# Weights store the weight in the format that software inference uses.
weights.append(weight)

q_name = ""
if quantizer:
if isinstance(quantizer, six.string_types):
q_name = quantizer
@@ -318,11 +319,10 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
q_name = quantizer.name
elif hasattr(quantizer, "__class__"):
q_name = quantizer.__class__.__name__
else:
q_name = ""

if quantizer and ("_po2" in q_name):
# Quantized_relu_po2 does not have a sign.
if isinstance(quantizer, quantized_po2):
if q_name == "quantized_po2":
has_sign = True
sign = np.sign(weight)
# Makes sure values are -1 or +1 only
@@ -332,7 +332,7 @@ def model_save_quantized_weights(model, filename=None, custom_objects={}):
hw_weight = np.round(np.log2(np.abs(weight)))
signs.append(sign)
scales.append([])
elif (isinstance(quantizer, quantized_bits) and
elif (q_name == "quantized_bits" and
quantizer.alpha == "auto_po2"):
unsigned_bits = quantizer.bits - quantizer.keep_negative
m = K.cast_to_floatx(pow(2, unsigned_bits))
@@ -1352,3 +1352,181 @@ def quantized_model_dump(model,
print("writing the layer output tensor to ", filename)
with open(filename, "w") as fid:
tensor_data.astype(np.float32).tofile(fid)


def clone_model_and_freeze_auto_po2_scale(
orig_model=None, orig_model_path=None, quantize_model_weights=False):
"""Clone model and freeze the scale value of auto_po2 type quantizers.
Args:
orig_model: original model which will be used to clone the new model.
If set to None, the function will load the original model
from orig_model_path argument.
orig_model_path: The path to the original model file.
If set to None, the function will load the original model from the
orig_model argument.
quantize_model_weights: Bool to quantize weights to HW format.
If set to False, the model weights will be in float format.
If set to True, the model weights will be in HW format and the function
will also check if the hw weights extracted from the new model matches
the original model.
Returns:
A tuple of the new model and the new model's hw weights.
Note:
+ When using this function to retrain model with fixed scale value.
Set quantize_model_weights to False in this case.
+ This function only supports a collection of common layers that will use
auto_po2 quantizers. For less common layers, it will raise errors and we
will add more support case by case.
Example usage:
model, _ = clone_model_and_freeze_auto_po2_scale(
orig_model_path="path/to/model",
quantize_model_weights=False)
"""

def _create_bn_layer(layer_cfg, bn_inv_quantizer):
# Clone batch normalization layer with the new inverse quantizer.
if bn_inv_quantizer is not None:
layer_cfg["inverse_quantizer"]["config"] = bn_inv_quantizer.get_config()
return QBatchNormalization(**layer_cfg)

def _create_qconv2d_layer(layer_cfg, kernel_quantizer):
# Clone QConv2D layer with the new kernel quantizer.
if kernel_quantizer is not None:
layer_cfg["kernel_quantizer"]["config"] = kernel_quantizer.get_config()
return QConv2D(**layer_cfg)

def _create_qdepthwise_conv2d_layer(layer_cfg, depthwise_quantizer):
# Clone QDepthwiseConv2D layer with the new depthwise quantizer.
if depthwise_quantizer is not None:
layer_cfg["depthwise_quantizer"][
"config"] = depthwise_quantizer.get_config()
return QDepthwiseConv2D(**layer_cfg)

def _create_qdense_layer(layer_cfg, kernel_quantizer):
# Clone QDense layer with the new kernel quantizer.
if kernel_quantizer is not None:
layer_cfg["kernel_quantizer"]["config"] = kernel_quantizer.get_config()
return QDense(**layer_cfg)

def _create_other_layer(orig_layer):
# Clone other layers.
config = orig_layer.get_config()
return orig_layer.__class__.from_config(config)

def _create_quantized_bits_with_post_training_scale(q):
# Create a new quantized_bits instance with the fixed scale value.
if q is not None:
q_cfg = q.get_config()
q_cfg["post_training_scale"] = q.scale.numpy()
q = quantized_bits(**q_cfg)
return q

def _find_auto_po2_quantizer(layer):
# Find the auto_po2 quantizer in the layer. Note that we allow at
# most one auto_po2 quantizer in each layer due to the limitation of
# the current HW implementation.
num_auto_po2_quantizers = 0
auto_po2_quantizer = None
if hasattr(layer, "quantizers"):
for q in layer.quantizers:
if hasattr(q, "alpha") and q.alpha == "auto_po2":
num_auto_po2_quantizers += 1
auto_po2_quantizer = q
if num_auto_po2_quantizers > 1:
raise ValueError(
f"{layer.name} has more than one auto_po2 quantizer. "
"Please check if this is expected.")
else:
return auto_po2_quantizer

def _check_hw_weights_equal(hw_weights_1, hw_weights_2):
# Check if the hw weights extracted from the new model matches the
# original model.
for layer_name in hw_weights_2.keys():
for key in hw_weights_2[layer_name].keys():

val1 = hw_weights_2[layer_name][key]
val2 = hw_weights_1[layer_name][key]
if isinstance(val1, list):
for (v1, v2) in zip(val1, val2):
if not np.all(v1 == v2):
raise ValueError(
f"{layer_name}/{key}: No Match! v1={v1}, v2={v2}")
else:
if not np.all(val1 == val2):
raise ValueError(
f"{layer_name}/{key}: No Match! val1={val1}, val2={val2}")

# Load the original model with float weights.
# Note: weights will be quantized later in silicon flow by calling
# model_save_quantized_weights.
if orig_model is not None and orig_model_path is not None:
raise ValueError(
"Only one of orig_model and orig_model_path can be set.")
elif orig_model is None and orig_model_path is None:
raise ValueError(
"One of orig_model and orig_model_path must be set.")
elif orig_model_path is not None:
orig_model = load_qmodel(orig_model_path, compile=False)

# Quantize model weights and compute quantizer scale values.
quantized_model = tf.keras.models.clone_model(orig_model)
quantized_model.set_weights(orig_model.get_weights())
# In silicon flow, weight binary files are generated from hw weights.
orig_hw_weights = model_save_quantized_weights(
quantized_model)

# Create a new model with fixed scale quantizers.
x = inputs = tf.keras.Input(
shape=orig_model.input_shape[1:], name=orig_model.layers[0].name)
for layer in quantized_model.layers[1:]:
layer_class = layer.__class__.__name__
auto_po2_quantizer = _find_auto_po2_quantizer(layer)
auto_po2_quantizer_with_frozen_scale = (
_create_quantized_bits_with_post_training_scale(auto_po2_quantizer))
layer_cfg = layer.get_config()

# To be compatible with different python versions, we do not use
# match-case style here.
if layer_class == "QConv2D":
x = _create_qconv2d_layer(layer_cfg,
auto_po2_quantizer_with_frozen_scale)(x)
elif layer_class == "QDepthwiseConv2D":
x = _create_qdepthwise_conv2d_layer(
layer_cfg, auto_po2_quantizer_with_frozen_scale)(x)
elif layer_class == "QBatchNormalization":
x = _create_bn_layer(layer_cfg,
auto_po2_quantizer_with_frozen_scale)(x)
elif layer_class == "QDense":
x = _create_qdense_layer(layer_cfg,
auto_po2_quantizer_with_frozen_scale)(x)
else:
x = _create_other_layer(layer)(x)

new_model = tf.keras.Model(inputs, x)
# Set the new model's weights to the original model's (float) weights.
new_model.set_weights(orig_model.get_weights())

# Check if the new model still has auto_po2 quantizers.
# This function only supports a collection of common layers that use
# auto_po2 quantizers. For less common layers, we need to add extra support
# in the future.
for layer in new_model.layers:
q = _find_auto_po2_quantizer(layer)
if q is not None and q.post_training_scale is None:
raise ValueError(
f"{layer.name} in the new model still has auto_po2 quantizer with "
"adaptive scales. Please check if this is expected!")

new_hw_weights = None
if quantize_model_weights:
new_hw_weights = model_save_quantized_weights(new_model)
# Check if the hw weights extracted from the new model match the
# original model.
_check_hw_weights_equal(orig_hw_weights, new_hw_weights)

return new_model, new_hw_weights
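
The Note in the docstring above describes retraining a model with the scale values held fixed. A minimal sketch of what that flow could look like is shown below; it is illustrative only and not part of this commit: orig_model is assumed to be an already-trained QKeras model whose quantizers use alpha="auto_po2", and the data, optimizer, and loss are placeholders.

import numpy as np
from qkeras.utils import clone_model_and_freeze_auto_po2_scale

# Clone the trained model, keeping float weights but freezing each
# auto_po2 scale to its current power-of-two estimate.
retrain_model, _ = clone_model_and_freeze_auto_po2_scale(
    orig_model, quantize_model_weights=False)

# Fine-tune with the scales held fixed. Placeholder data shaped like the
# test model below: 4x4x1 inputs and 2 outputs.
x_train = np.random.rand(8, 4, 4, 1).astype(np.float32)
y_train = np.random.rand(8, 2).astype(np.float32)
retrain_model.compile(optimizer="adam", loss="mse")
retrain_model.fit(x_train, y_train, epochs=1)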
114 changes: 114 additions & 0 deletions tests/utils_test.py
@@ -20,6 +20,8 @@

import numpy as np
import pytest
import os
import tempfile
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

@@ -30,6 +32,8 @@
from qkeras.utils import is_TFOpLambda_layer
from qkeras.utils import find_bn_fusing_layer_pair
from qkeras.utils import add_bn_fusing_weights
from qkeras.utils import clone_model_and_freeze_auto_po2_scale
from qkeras.utils import load_qmodel


def create_quantized_network():
@@ -223,5 +227,115 @@ def test_find_bn_fusing_layer_pair():
assert np.all(d["fused_bias"] == np.array([0.09375, 0.65625]))


def create_test_model_for_scale_freezing(bias_quantizer):
def _create_simple_model(bias_quantizer):
x = x_in = tf.keras.Input((4, 4, 1), name="input")
x = QConv2D(
filters=4, kernel_size=2, strides=2,
kernel_quantizer=quantized_bits(4, 2, 1, alpha="auto_po2"),
bias_quantizer=quantized_bits(4, 2, 1),
use_bias=False,
name="conv")(x)
x = QDepthwiseConv2D(
kernel_size=2, strides=1,
depthwise_quantizer=quantized_bits(6, 3, 1, alpha="auto_po2"),
use_bias=False,
bias_quantizer=quantized_bits(4, 2, 1),
name="dw_conv")(x)
x = QBatchNormalization(
mean_quantizer=quantized_bits(4, 2, 1),
gamma_quantizer=None,
variance_quantizer=None,
beta_quantizer=quantized_bits(4, 0, 1),
inverse_quantizer=quantized_bits(8, 0, 1, alpha="auto_po2"),
name="bn")(x)

x = QActivation(activation=quantized_bits(4, 0), name="relu")(x)
x = tf.keras.layers.Flatten(name="flatten")(x)
x = QDense(units=2,
kernel_quantizer=quantized_bits(4, 2, 1, alpha="auto_po2"),
bias_quantizer=bias_quantizer, name="dense")(x)
model = tf.keras.Model(inputs=x_in, outputs=x)

return model

def _set_weights(model):
conv_w = [np.array(
[0.23, 2.76, 0.1, 0.33, 0.53, 0.16, 0.3, 1.7, -0.9,
1.43, 2.31, -0.2, -1.7, 0.39, -2.03, 1.79]).reshape(2, 2, 1, 4)]

dw_conv_w = [np.array([
0.03, 3.6, 2.1, 1.2, 0.13, 1.3, -0.3, 1.2, -0.7,
-10.3, 11.7, -0.92, -10.7, 0.59, -1.93, 2.8]).reshape((2, 2, 4, 1))]

bn_w = [np.array([0.28, 1.33, 2.27, 3.36]),
np.array([0.31, 0.1, 0.03, 4.26]),
np.array([0.89, -0.21, 1.97, 2.06]),
np.array([1.2, 0.9, 13.2, 10.9])]

dense_w = np.array(
[0.13, 0.66, 0.21, 0.23, 1.07, -0.79, 1.83, 1.81])
dense_w = [dense_w.reshape((4, 2)), np.array([-1.3, 0.7])]

model.get_layer("conv").set_weights(conv_w)
model.get_layer("dw_conv").set_weights(dw_conv_w)
model.get_layer("bn").set_weights(bn_w)
model.get_layer("dense").set_weights(dense_w)

orig_model = _create_simple_model(bias_quantizer)
_set_weights(orig_model)

return orig_model


def test_clone_model_and_freeze_auto_po2_scale():
"""Test clone_model_and_freeze_auto_po2_scale to work properly."""

orig_model = create_test_model_for_scale_freezing(quantized_bits(4, 2, 1))
_, new_hw = clone_model_and_freeze_auto_po2_scale(
orig_model, quantize_model_weights=True)

# Check if the new model's weights and scales are derived properly.
np.testing.assert_array_equal(
new_hw["conv"]["weights"][0],
np.array(
[[[[0.5, 6, 0, 0.5]], [[1, 0, 0.5, 3.5]]],
[[[-2., 3., 3.5, -0.5]], [[-3.5, 1., -3.5, 3.5]]]]))

np.testing.assert_array_equal(
new_hw["conv"]["scales"][0], np.array([[[[0.25, 0.5, 0.25, 0.25]]]]))

np.testing.assert_array_equal(
new_hw["dw_conv"]["weights"][0].numpy().flatten(),
np.array([
0., 14, 8, 4, 0, 6, -2, 4, -2, -42, 46, -4, -42, 2, -8, 12]))

np.testing.assert_array_equal(
new_hw["dense"]["scales"][0], np.array([[0.25, 0.25]]))


def test_clone_model_and_freeze_auto_po2_scale_serialization():
# Test if the cloned model can be saved and loaded properly.
orig_model = create_test_model_for_scale_freezing(quantized_bits(4, 2, 1))
new_model, _ = clone_model_and_freeze_auto_po2_scale(
orig_model, quantize_model_weights=True)

fd, fname = tempfile.mkstemp(".hdf5")
new_model.save(fname)
_ = load_qmodel(fname)
os.close(fd)
os.remove(fname)


def test_clone_model_and_freeze_auto_po2_scale_error():
orig_model = create_test_model_for_scale_freezing(
quantized_bits(4, 2, 1, alpha="auto_po2"))
# Test that the function raises an error when a layer has more than one
# auto_po2 quantizer.
with pytest.raises(ValueError):
clone_model_and_freeze_auto_po2_scale(
orig_model, quantize_model_weights=False)


if __name__ == "__main__":
pytest.main([__file__])
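
Beyond the assertions in these tests, the frozen scales can also be sanity-checked directly on a cloned model, since clone_model_and_freeze_auto_po2_scale stores a fixed post_training_scale on each auto_po2 quantizer it rewrites. A small sketch under the same assumptions as the tests above (reusing create_test_model_for_scale_freezing and the "conv" layer name):

orig_model = create_test_model_for_scale_freezing(quantized_bits(4, 2, 1))
new_model, _ = clone_model_and_freeze_auto_po2_scale(
    orig_model, quantize_model_weights=False)

# Every auto_po2 quantizer in the cloned layer should now carry a fixed
# (non-None) post_training_scale instead of an adaptive scale.
for q in new_model.get_layer("conv").quantizers:
  if hasattr(q, "alpha") and q.alpha == "auto_po2":
    assert q.post_training_scale is not None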
