Skip to content

Commit

Permalink
Start fixing eval mode of the TextNetFast model
Browse files Browse the repository at this point in the history
  • Loading branch information
nikokks committed Sep 9, 2023
1 parent dcd2ece commit a8ac914
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 34 deletions.
14 changes: 6 additions & 8 deletions doctr/models/classification/textnet_fast/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,14 @@ def __init__(
stage4: List[Dict[str, Union[int, List[int]]]],
include_top: bool = True,
num_classes: int = 1000,
input_shape: Optional[Tuple[int, int, int]] = None,
cfg: Optional[Dict[str, Any]] = None,
input_shape: Optional[Tuple[int, int, int]] = None,
) -> None:
first_conv = tf.keras.Sequential(
conv_sequence(out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2)
)
_layers = [first_conv]
_layers = [
tf.keras.Sequential(
conv_sequence(out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape),
)
]

for stage in [stage1, stage2, stage3, stage4]:
stage_ = tf.keras.Sequential([RepConvLayer(**params) for params in stage])
Expand All @@ -94,19 +95,16 @@ def eval(self, mode=False):
self = rep_model_convert(self)
self = fuse_module(self)
self.trainable = mode
return self

# Passes the model through `unfuse_module` and then `rep_model_unconvert`,
# rebinding the local name `self` to each result, then sets `trainable` to
# `mode` (default True) on whatever object `self` refers to at that point.
# NOTE(review): this text comes from a rendered diff with indentation stripped
# and removed/added lines shown together; `return self` may be a removed line
# — confirm against the repository before relying on the return value.
def train(self, mode=True):
self = unfuse_module(self)
self = rep_model_unconvert(self)
# Applied after both conversions, so it targets the converted object.
self.trainable = mode
return self

# Passes the model through `rep_model_convert_deploy` and then `fuse_module`,
# rebinding the local name `self` to each result, then sets `trainable` to
# `mode` (default False) on whatever object `self` refers to at that point.
# NOTE(review): this text comes from a rendered diff with indentation stripped
# and removed/added lines shown together; `return self` may be a removed line
# — confirm against the repository before relying on the return value.
def test(self, mode=False):
self = rep_model_convert_deploy(self)
self = fuse_module(self)
# Applied after both conversions, so it targets the converted object.
self.trainable = mode
return self


def _textnetfast(
Expand Down
20 changes: 10 additions & 10 deletions doctr/models/modules/layers/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
use_bias=False,
input_shape=(None, None, in_channels),
),
layers.BatchNormalization()
]
)

self.main_bn = layers.BatchNormalization()

if kernel_size[1] != 1:
self.ver_conv = tf.keras.Sequential(
[
Expand All @@ -45,12 +44,12 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
use_bias=False,
input_shape=(None, None, in_channels),
),
layers.BatchNormalization()
]
)

self.ver_bn = layers.BatchNormalization()
else:
self.ver_conv, self.ver_bn = None, None
self.ver_conv = None

if kernel_size[0] != 1:
self.hor_conv = tf.keras.Sequential(
Expand All @@ -65,23 +64,24 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
use_bias=False,
input_shape=(None, None, in_channels),
),
layers.BatchNormalization()
]
)

self.hor_bn = layers.BatchNormalization()
else:
self.hor_conv, self.hor_bn = None, None
self.hor_conv = None

self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None

Check notice on line 74 in doctr/models/modules/layers/tensorflow.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/modules/layers/tensorflow.py#L74

Trailing whitespace
self.layers = [self.main_conv, self.ver_conv, self.hor_conv, self.rbr_identity, self.activation]

def call(
self,
x: tf.Tensor,
**kwargs: Any,
) -> tf.Tensor:
main_outputs = self.main_bn(self.main_conv(x, **kwargs), **kwargs)
vertical_outputs = self.ver_bn(self.ver_conv(x, **kwargs), **kwargs) if self.ver_conv is not None else 0
horizontal_outputs = self.hor_bn(self.hor_conv(x, **kwargs), **kwargs) if self.hor_conv is not None else 0
main_outputs = self.main_conv(x, **kwargs)
vertical_outputs = self.ver_conv(x, **kwargs) if self.ver_conv is not None else 0
horizontal_outputs = self.hor_conv(x, **kwargs) if self.hor_conv is not None else 0
id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0

p = main_outputs + vertical_outputs
Expand Down
30 changes: 14 additions & 16 deletions doctr/models/utils/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import tf2onnx
from tensorflow.keras import Model, layers

from doctr.models.modules.layers.tensorflow import RepConvLayer
from doctr.utils.data import download_from_url

logging.getLogger("tensorflow").setLevel(logging.DEBUG)
Expand Down Expand Up @@ -188,13 +189,15 @@ def fuse_conv_bn(conv, bn):
only the mean and variance along channels are used, which exposes the opportunity
to fuse it with the preceding conv layers to save computations and simplify
network structures."""
print(dir(conv))
conv_weights, conv_biases = conv.get_weights()
bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()

if conv_biases is None:

Check notice on line 193 in doctr/models/utils/tensorflow.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/utils/tensorflow.py#L193

Trailing whitespace
bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()
weights = conv.get_weights()
if len(weights) == 1:
conv_weights = weights[0]
conv_biases = np.zeros_like(bn_running_mean)

else:
conv_weights, conv_biases = conv.get_weights()
epsilon = bn.epsilon
scale_factor = bn_weights / np.sqrt(bn_running_var + epsilon)

Expand All @@ -205,30 +208,25 @@ def fuse_conv_bn(conv, bn):
fused_conv_weights = conv_weights * scale_factor
fused_conv_biases = (conv_biases - bn_running_mean) * scale_factor.flatten() + bn_biases

# Setting the updated weights and biases in conv layer
conv.use_bias = True
conv.build(input_shape=conv.input_shape)
conv.set_weights([fused_conv_weights, fused_conv_biases])
conv.old_weight, conv.old_biais = conv.get_weights()
return conv
conv.old_weight, conv.old_biais = conv_weights, conv_biases


# Walks `model.layers` in order and calls `fuse_conv_bn` on each
# BatchNormalization layer together with the most recently seen Conv2D,
# recursing into nested containers. Returns the `model` object it was given.
# NOTE(review): this block comes from a rendered diff with indentation
# stripped; the `print` calls and the `else:` / `elif` pair below look like
# removed and added lines shown side by side — confirm against the repository.
def fuse_module(model):
# Most recent Conv2D seen at this nesting level; None until one is found.
last_conv = None

for layer in model.layers:
print(layer)
if isinstance(layer, (tf.keras.layers.BatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization)):
if last_conv is None: # only fuse BN that is after Conv
continue
# Fuse this BN into the preceding conv via `fuse_conv_bn`.
print('ok')
fuse_conv_bn(last_conv, layer)
# The BN layer itself is not removed or replaced with an identity here;
# doing so is non-trivial in TensorFlow because Keras models are not as
# dynamically modifiable as PyTorch models.
# NOTE(review): `last_conv` is not reset after a fuse — confirm intended.

elif isinstance(layer, tf.keras.layers.Conv2D):
last_conv = layer
else:
# Recurse into nested containers; each call starts with its own `last_conv`.
elif isinstance(layer, (tf.keras.Sequential, RepConvLayer)):
fuse_module(layer)
return model

Expand Down

0 comments on commit a8ac914

Please sign in to comment.