fix tf.nn.{conv2d,convolution} substitution #1275

Merged: 2 commits, Nov 24, 2024
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
@@ -30,7 +31,7 @@
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
    KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
-    ACTIVATION, LINEAR
    ACTIVATION, LINEAR, DATA_FORMAT, GROUPS, CHANNELS_FORMAT_FIRST, CHANNELS_FORMAT_LAST


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,42 +137,67 @@ def substitute(self,
        conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
        if len(conv_func_node.op_call_args) > 0:
            Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.")  # pragma: no cover
-        if STRIDES in conv_func_node.op_call_kwargs:
-            strides = conv_func_node.op_call_kwargs[STRIDES]
-            if len(strides) == 4:
-                if strides[0] > 1 or strides[3] > 1:
-                    # Non-standard strides -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[STRIDES] = strides[1:3]
-            else:
-                conv_fw_attr[STRIDES] = strides
-        if PADDING in conv_func_node.op_call_kwargs:
-            padding = conv_func_node.op_call_kwargs[PADDING]
-            if not isinstance(padding, str):
-                # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
-                return graph  # pragma: no cover
-            conv_fw_attr[PADDING] = padding
-        if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
-            dilations = conv_func_node.op_call_kwargs[DILATIONS]
-            if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
-                if dilations[0] > 1 or dilations[3] > 1:
-                    # Non-standard dilations -> skip substitution.
-                    return graph  # pragma: no cover
-                conv_fw_attr[DILATION_RATE] = dilations[1:3]
-            else:
-                conv_fw_attr[DILATION_RATE] = dilations

        strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
        if strides is None:
            # Non-standard strides -> skip substitution.
            return graph
        conv_fw_attr[STRIDES] = strides

        padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
        if not isinstance(padding, str):
            # Non-standard padding, the layer only supports 'valid' or 'same' -> skip substitution.
            return graph  # pragma: no cover
        conv_fw_attr[PADDING] = padding

        dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
        if dilations is None:
            # Non-standard dilations -> skip substitution.
            return graph
        conv_fw_attr[DILATION_RATE] = dilations

        if b is None:
            conv_fw_attr[USE_BIAS] = False
        else:
            weights[BIAS] = b

        data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
        conv_fw_attr[DATA_FORMAT] = {'NHWC': CHANNELS_FORMAT_LAST, 'NCHW': CHANNELS_FORMAT_FIRST}[data_format]

        conv_fw_attr[GROUPS] = 1

        _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
        conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
                             weights, Conv2D, **_reuse_params)

        replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
        return graph
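
    # Illustrative sketch (hypothetical call, not part of this PR): with the attributes
    # collected above, a call such as
    #     tf.nn.conv2d(x, k, strides=[1, 2, 2, 1], padding='SAME')
    # with k of shape (3, 3, 4, 16) and no following bias_add becomes a node equivalent to
    #     Conv2D(filters=16, kernel_size=(3, 3), strides=(2, 2), padding='SAME',
    #            activation='linear', use_bias=False, data_format='channels_last', groups=1)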

    def _parse_tf_stride_dilation(self, node, key) -> Optional[Tuple[int, int]]:
        """
        Extract stride/dilation param from tf node and convert it to keras format (suitable for Conv2D).

        Args:
            node: node
            key: param key

        Returns:
            Parsed value or None if non-standard.
        """
        v = node.op_call_kwargs.get(key)
        if v is None:
            return 1, 1

Collaborator:
this way you assume the defaults. why not return None?

Collaborator (author):
That's intentional. None wouldn't do; we need to fill in an explicit default. This method is specific to tf stride & dilation.
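
A minimal illustration of that asymmetry (editorial sketch, assumes TF 2.x; shapes are arbitrary): tf.nn.convolution treats omitted strides/dilations as 1, while a Keras Conv2D always carries explicit strides/dilation_rate attributes, so the parser has to produce an explicit (1, 1) rather than propagate None:

import tensorflow as tf

x = tf.random.uniform((1, 32, 32, 3))
k = tf.random.uniform((3, 3, 3, 8))

# Omitted strides/dilations behave as 1 in every spatial dimension:
y1 = tf.nn.convolution(x, k)
y2 = tf.nn.convolution(x, k, strides=[1, 1], dilations=[1, 1])
assert y1.shape == y2.shape == (1, 30, 30, 8)

# A Conv2D layer materializes the equivalent defaults explicitly:
layer = tf.keras.layers.Conv2D(filters=8, kernel_size=3)
print(layer.strides, layer.dilation_rate)  # (1, 1) (1, 1)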

        if isinstance(v, int):
            return v, v
        if len(v) == 1:
            return v[0], v[0]
        if len(v) == 4:
            # A stride/dilation on the batch or channel axis is non-standard.
            if v[0] > 1 or v[-1] > 1:
                return None
            return tuple(v[1:3])
        return tuple(v)
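
    # Illustrative mapping of the parser above (tf kwarg value -> Conv2D attribute):
    #     None          -> (1, 1)   explicit Keras-style default
    #     2             -> (2, 2)
    #     [2]           -> (2, 2)
    #     [2, 3]        -> (2, 3)
    #     [1, 2, 3, 1]  -> (2, 3)   unit batch/channel entries are dropped
    #     [2, 1, 1, 2]  -> None     non-standard -> the substitution is skipped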


class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
"""
@@ -39,6 +39,9 @@

class ConvFuncSubstitutionsTest(BaseKerasFeatureNetworkTest):

    def __init__(self, unit_test):
        super().__init__(unit_test, input_shape=(32, 32, 3))

    def get_tpc(self):
        tp = generate_test_tp_model({'enable_weights_quantization': False,
                                     'enable_activation_quantization': False})
@@ -67,6 +70,18 @@ def create_networks(self):
        x = tf.nn.convolution(x, np.random.random((3, 3, 2, 4)).astype(np.float32),
                              [2, 1], padding='SAME')
        x = tf.nn.bias_add(x, np.random.random((4,)).astype(np.float32))

        # default values and various formats
        x = tf.nn.conv2d(x, np.random.random((3, 3, 4, 8)), 1, 'VALID')
        x = tf.nn.conv2d(x, np.random.random((3, 3, 8, 16)), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.conv2d(x, np.random.random((3, 3, 16, 8)), strides=[1, 1], padding='VALID', dilations=[1])
        x = tf.nn.conv2d(x, filters=np.random.random((3, 3, 8, 4)), strides=[1, 1], padding='SAME', dilations=[1, 1])

        x = tf.nn.convolution(x, np.random.random((3, 3, 4, 16)).astype(np.float32))
        x = tf.nn.convolution(x, np.random.random((3, 3, 16, 32)).astype(np.float32), strides=[1], padding='SAME', dilations=1)
        x = tf.nn.convolution(x, np.random.random((3, 3, 32, 8)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1])
        x = tf.nn.convolution(x, filters=np.random.random((3, 3, 8, 4)).astype(np.float32), strides=[1, 1], padding='VALID', dilations=[1, 1])

        return tf.keras.Model(inputs=_in, outputs=x)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
@@ -75,7 +90,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
        cs = cosine_similarity(out_float.numpy(), out_quant.numpy())
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check: {cs}')

-        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 4,
        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, Conv2D)) == 12,
                                  "Not all conv functions were substituted.")
        self.unit_test.assertTrue(len(get_layers_from_model_by_type(quantized_model, DepthwiseConv2D)) == 2,
                                  "Not all dw-conv functions were substituted.")
@@ -54,6 +54,8 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
        y = float_model.predict(input_x)
        y_hat = quantized_model.predict(input_x)
        self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
        # FIXME this doesn't test anything, the number of quantized convs in the network is exactly 0. Even if it
        #  looked at correct layers it hardly checks anything.

Collaborator:
then why not remove it?

        self.unit_test.assertTrue(len([l for l in quantized_model.layers if isinstance(l, KerasTrainableQuantizationWrapper) and isinstance(l.layer, layers.Conv2D)]) < len([l for l in float_model.layers if isinstance(l, layers.Conv2D)]), msg=f'fail number of layers should decrease!')
        cs = cosine_similarity(y, y_hat)
        self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
@@ -75,6 +77,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
            if type(layer) == layers.Conv2D:
                self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class ThreeConv2DCollapsingTest(BaseConv2DCollapsingTest):
    def __init__(self, unit_test):
        super().__init__(unit_test)
@@ -107,9 +110,35 @@ def create_networks(self):

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        super().compare(quantized_model, float_model, input_x, quantization_info)
-        for layer in quantized_model.layers:
-            if type(layer) == layers.Conv2D:
-                self.unit_test.assertTrue(len(layer.weights) == 2,msg=f'fail Bias should appear in weights!!')
        convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
        self.unit_test.assertTrue(len(convs) == 1)
        for layer in convs:
            self.unit_test.assertTrue(len(layer.weights) == 2, msg=f'fail Bias should appear in weights!!')


class FuncConvCollapsingTest(FourConv2DCollapsingTest):
def create_networks(self):
# Tests the combination of functional conv to Conv2D substitution with linear collapsing
# (in case of default values, tf layer doesn't contain these attributes, and they must be added explicitly
# to node's attributes dict, which is not covered by substitution test)
h, w, c = self.get_input_shapes()[0][1:]
inputs = layers.Input(shape=(h, w, c))
x = tf.nn.conv2d(inputs, tf.random.uniform((3, 3, c, 16)), 1, 'SAME')
x = tf.nn.convolution(x, tf.random.uniform((1, 1, 16, 8)))
x = tf.nn.relu(x)
x = tf.nn.convolution(x, tf.random.uniform((3, 3, 8, 32)))
y = tf.nn.conv2d(x, tf.random.uniform((1, 1, 32, 4)), 1, 'VALID')
return tf.keras.models.Model(inputs=inputs, outputs=y)

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
convs = [l for l in quantized_model.layers if isinstance(l, layers.Conv2D)]
self.unit_test.assertTrue(len(convs) == 2)

y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
cs = cosine_similarity(y, y_hat)
self.unit_test.assertTrue(np.isclose(cs, 1), msg=f'fail cosine similarity check:{cs}')
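
# A standalone sketch of why the collapsing is sound (illustrative shapes, not part
# of this PR): a conv followed by a 1x1 conv with no activation in between equals a
# single conv whose kernel contracts the first kernel's output channels with the second.
import tensorflow as tf

x = tf.random.uniform((1, 8, 8, 3))
k1 = tf.random.uniform((3, 3, 3, 16))
k2 = tf.random.uniform((1, 1, 16, 4))

y_stacked = tf.nn.conv2d(tf.nn.conv2d(x, k1, 1, 'VALID'), k2, 1, 'VALID')
k12 = tf.einsum('hwio,oj->hwij', k1, k2[0, 0])  # fuse the two kernels
y_fused = tf.nn.conv2d(x, k12, 1, 'VALID')
assert float(tf.reduce_max(tf.abs(y_stacked - y_fused))) < 1e-4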


class SixConv2DCollapsingTest(BaseConv2DCollapsingTest):
@@ -56,7 +56,8 @@
from tests.keras_tests.feature_networks_tests.feature_networks.input_scaling_test import InputScalingDenseTest, \
    InputScalingConvTest, InputScalingDWTest, InputScalingZeroPadTest
from tests.keras_tests.feature_networks_tests.feature_networks.linear_collapsing_test import TwoConv2DCollapsingTest, \
-    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest
    ThreeConv2DCollapsingTest, FourConv2DCollapsingTest, SixConv2DCollapsingTest, Op2DAddConstCollapsingTest, \
    FuncConvCollapsingTest
from tests.keras_tests.feature_networks_tests.feature_networks.lut_quantizer import LUTWeightsQuantizerTest, \
    LUTActivationQuantizerTest
from tests.keras_tests.feature_networks_tests.feature_networks.manual_bit_selection import ManualBitWidthSelectionTest, \
@@ -605,6 +606,7 @@ def test_linear_collapsing(self):
        FourConv2DCollapsingTest(self).run_test()
        SixConv2DCollapsingTest(self).run_test()
        Op2DAddConstCollapsingTest(self).run_test()
        FuncConvCollapsingTest(self).run_test()

    def test_const_quantization(self):
        c = (np.ones((32, 32, 16)) + np.random.random((32, 32, 16))).astype(np.float32)