From 03c58348300b505ade75b35c2b1efd06d0d7560e Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 10:16:12 +0200
Subject: [PATCH 01/44] adding TextNet backbone: first try

---
 doctr/models/classification/fast/__init__.py  |   6 +
 doctr/models/classification/fast/pytorch.py   | 296 ++++++++++++++++++
 .../models/classification/fast/tensorflow.py  |   0
 3 files changed, 302 insertions(+)
 create mode 100644 doctr/models/classification/fast/__init__.py
 create mode 100644 doctr/models/classification/fast/pytorch.py
 create mode 100644 doctr/models/classification/fast/tensorflow.py

diff --git a/doctr/models/classification/fast/__init__.py b/doctr/models/classification/fast/__init__.py
new file mode 100644
index 0000000000..c7110f5669
--- /dev/null
+++ b/doctr/models/classification/fast/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]
diff --git a/doctr/models/classification/fast/pytorch.py b/doctr/models/classification/fast/pytorch.py
new file mode 100644
index 0000000000..1048c07b79
--- /dev/null
+++ b/doctr/models/classification/fast/pytorch.py
@@ -0,0 +1,296 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+import torch.nn as nn
+
+from doctr.datasets import VOCABS
+
+from ...utils import conv_sequence_pt, load_pretrained_params
+
+__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnet_tiny": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnet_small": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnet_base": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+}
+
+
+
+
+
+class TextNet(nn.Module):
+    """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/abs/2111.02394>>`_.
+
+    Args:
+        num_blocks: number of resnet block in each stage
+        output_channels: number of channels in each stage
+        stage_conv: whether to add a conv_sequence after each stage
+        stage_pooling: pooling to add after each stage (if None, no pooling)
+        origin_stem: whether to use the orginal ResNet stem or ResNet-31's
+        stem_channels: number of output channels of the stem convolutions
+        attn_module: attention module to use in each stage
+        include_top: whether the classifier head should be instantiated
+        num_classes: number of output classes
+    """
+
+    def __init__(
+        self,
+        first_conv: Dict[str, Any],
+        stage1: List[Dict[str, Any]],
+        stage2: List[Dict[str, Any]],
+        stage3: List[Dict[str, Any]],
+        stage4: List[Dict[str, Any]],
+        include_top: bool = True,
+        num_classes: int = 1000,
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+    
+        super(TextNet, self).__init__()
+        
+        self.first_conv = first_conv
+        self.stage1 = nn.ModuleList(stage1)
+        self.stage2 = nn.ModuleList(stage2)
+        self.stage3 = nn.ModuleList(stage3)
+        self.stage4 = nn.ModuleList(stage4)
+        
+        _layers: List[nn.Module]
+        
+        _layers = [self.first_conv, self.stage1, self.stage2, self.stage3, self.stage4]
+        
+        if include_top:
+            _layers.extend(
+                [
+                    nn.AdaptiveAvgPool2d(1),
+                    nn.Flatten(1),
+                    nn.Linear(stage4[-1]["out_channels"], num_classes, bias=True),
+                ]
+            )
+
+        super().__init__(*_layers)
+        self.cfg = cfg
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+
+def _textnet(
+    arch: str,
+    pretrained: bool,
+    arch_fn,
+    ignore_keys: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> TextNet:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = arch_fn(**kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    model.cfg = _cfg
+
+    return model
+
+
+def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnet_tiny
+    >>> model = textnet_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNet model
+    """
+
+    return _textnet(
+        "textnet_tiny",
+        pretrained,
+        TextNet,
+        first_conv = {"name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": false, "in_channels": 3,
+                      "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
+        stage1 = [ {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage2 = [ {"name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage3 = [ {"name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage4 = [ {"name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1}],
+                   
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
+    
+def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnet_small
+    >>> model = textnet_small(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNet model
+    """
+
+    return _textnet(
+        "textnet_small",
+        pretrained,
+        TextNet,
+        first_conv = { "name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": False, "in_channels": 3,
+                       "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
+        stage1 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1}],
+                   
+        stage2 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage3 = [ { "name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage4 = [ { "name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
+    
+def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnet_base
+    >>> model = textnet_base(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNet model
+    """
+
+    return _textnet(
+        "textnet_base",
+        pretrained,
+        TextNet,
+        first_conv = { "name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": False, "in_channels": 3,
+                       "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
+        stage1 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+
+        stage2 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+
+        stage3 = [ { "name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+                   
+        stage4 = [ { "name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
+                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},]
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
diff --git a/doctr/models/classification/fast/tensorflow.py b/doctr/models/classification/fast/tensorflow.py
new file mode 100644
index 0000000000..e69de29bb2
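
At this stage the builder cannot run yet (the classifier head and the duplicated super().__init__ calls are still being reworked in the following patches), but the intent of `_textnet` is already visible: it resolves `num_classes` against the default vocab and only drops the classifier weights when the head size differs. A sketch of the intended usage once the module imports cleanly (the re-export through `doctr.models` only lands in a later patch):

    from doctr.models.classification.fast.pytorch import textnet_tiny

    # Default head: one logit per character of VOCABS["french"] (126 classes here).
    model = textnet_tiny(pretrained=False)

    # With a custom head size and pretrained=True, _textnet would forward
    # ignore_keys=["fc.weight", "fc.bias"] to load_pretrained_params so the
    # mismatched classifier weights are re-initialized rather than loaded.
    binary_model = textnet_tiny(pretrained=False, num_classes=2)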

From ceeca802c39d895dbb4b07aee1f67bc3390a3dfe Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 12:10:34 +0200
Subject: [PATCH 02/44] second try: uploading TextNet backbone

---
 doctr/models/classification/fast/pytorch.py | 176 +++++++++-----------
 doctr/models/modules/TextNet/__init__.py    |   6 +
 doctr/models/modules/TextNet/pytorch.py     |  50 ++++++
 3 files changed, 139 insertions(+), 93 deletions(-)
 create mode 100644 doctr/models/modules/TextNet/__init__.py
 create mode 100644 doctr/models/modules/TextNet/pytorch.py

diff --git a/doctr/models/classification/fast/pytorch.py b/doctr/models/classification/fast/pytorch.py
index 1048c07b79..6cf84fd936 100644
--- a/doctr/models/classification/fast/pytorch.py
+++ b/doctr/models/classification/fast/pytorch.py
@@ -10,7 +10,9 @@
 
 from doctr.datasets import VOCABS
 
-from ...utils import conv_sequence_pt, load_pretrained_params
+from ...utils import load_pretrained_params
+
+from doctr.models.utils.pytorch import conv_sequence_pt
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
@@ -40,9 +42,6 @@
 }
 
 
-
-
-
 class TextNet(nn.Module):
     """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
      <https://arxiv.org/abs/2111.02394>>`_.
@@ -61,7 +60,6 @@ class TextNet(nn.Module):
 
     def __init__(
         self,
-        first_conv: Dict[str, Any],
         stage1: List[Dict[str, Any]],
         stage2: List[Dict[str, Any]],
         stage3: List[Dict[str, Any]],
@@ -73,15 +71,13 @@ def __init__(
     
         super(TextNet, self).__init__()
         
-        self.first_conv = first_conv
-        self.stage1 = nn.ModuleList(stage1)
-        self.stage2 = nn.ModuleList(stage2)
-        self.stage3 = nn.ModuleList(stage3)
-        self.stage4 = nn.ModuleList(stage4)
-        
-        _layers: List[nn.Module]
-        
-        _layers = [self.first_conv, self.stage1, self.stage2, self.stage3, self.stage4]
+        _layers: List[nn.Module] = []
+        self.first_conv = nn.ModuleList(conv_sequence(in_channels, out_channels, True, True, kernel_size=kernel_size, stride=stride))
+
+        _layers.extend([self.first_conv])
+        for stage in [stage1, stage2, stage3, stage4]:
+            stage_ = nn.ModuleList([RepConvLayer(in_channels, out_channels, kernel_size, stride) for in_channels,out_channels,kernel_size,stride in stage])
+            _layers.extend([stage_])
         
         if include_top:
             _layers.extend(
@@ -154,26 +150,24 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
         "textnet_tiny",
         pretrained,
         TextNet,
-        first_conv = {"name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": false, "in_channels": 3,
-                      "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
-        stage1 = [ {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2,
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
                    
-        stage2 = [ {"name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}],
                    
-        stage3 = [ {"name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},],
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},],
                    
-        stage4 = [ {"name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   {"name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1}],
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1}],
                    
         ignore_keys=["fc.weight", "fc.bias"],
         **kwargs,
@@ -200,34 +194,32 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
         "textnet_small",
         pretrained,
         TextNet,
-        first_conv = { "name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": False, "in_channels": 3,
-                       "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
-        stage1 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1}],
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}],
                    
-        stage2 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
                    
-        stage3 = [ { "name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},],
                    
-        stage4 = [ { "name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},],
         ignore_keys=["fc.weight", "fc.bias"],
         **kwargs,
     )
@@ -253,44 +245,42 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
         "textnet_base",
         pretrained,
         TextNet,
-        first_conv = { "name": "ConvLayer", "kernel_size": 3, "stride": 2, "dilation": 1, "groups": 1, "bias": False, "has_shuffle": False, "in_channels": 3,
-                       "out_channels": 64, "use_bn": True, "act_func": "relu", "dropout_rate": 0, "ops_order": "weight_bn_act"},
-        stage1 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
-
-        stage2 = [ { "name": "RepConvLayer", "in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},],
-
-        stage3 = [ { "name": "RepConvLayer", "in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},],
+        stage1 = [ {"kernel_size": [3, 3], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 2},
+                   {"kernel_size": [3, 1], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 1},
+                   {"kernel_size": [3, 1], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 1},
+                   {"kernel_size": [1, 3], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 1},
+                   {"kernel_size": [3, 3], "stride": 1}],
+
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
+
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},],
                    
-        stage4 = [ { "name": "RepConvLayer", "in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1, "dilation": 1, "groups": 1},
-                   { "name": "RepConvLayer", "in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1, "dilation": 1, "groups": 1},]
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},]
         ignore_keys=["fc.weight", "fc.bias"],
         **kwargs,
     )
diff --git a/doctr/models/modules/TextNet/__init__.py b/doctr/models/modules/TextNet/__init__.py
new file mode 100644
index 0000000000..c7110f5669
--- /dev/null
+++ b/doctr/models/modules/TextNet/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]
diff --git a/doctr/models/modules/TextNet/pytorch.py b/doctr/models/modules/TextNet/pytorch.py
new file mode 100644
index 0000000000..5ed5eb3263
--- /dev/null
+++ b/doctr/models/modules/TextNet/pytorch.py
@@ -0,0 +1,50 @@
+import torch.nn as nn
+
+class RepConvLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride):
+        super(RepConvLayer, self).__init__()
+        self.ver_conv, self.ver_bn = None, None
+        self.hor_conv, self.hor_bn = None, None
+
+        self.activation = nn.ReLU(inplace=True)
+        self.main_conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size[0] // 2, kernel_size[1] // 2),  # per-dimension padding so asymmetric kernels keep the spatial size
+            bias=False,
+        )
+        self.main_bn = nn.BatchNorm2d(out_channels)
+
+        if kernel_size[1] != 1:
+            self.ver_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(kernel_size[0], 1),
+                stride=stride,
+                padding=(kernel_size[0] // 2, 0),
+                bias=False
+            )
+            self.ver_bn = nn.BatchNorm2d(out_channels)
+
+        if kernel_size[0] != 1:
+            self.hor_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(1, kernel_size[1]),
+                stride=stride,
+                padding=(0, kernel_size[1] // 2),
+                bias=False
+            )
+            self.hor_bn = nn.BatchNorm2d(out_channels)
+        self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None  # identity branch only when shapes are preserved
+
+
+    def forward(self, input):
+        main_outputs = self.main_bn(self.main_conv(input))
+        vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
+        horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
+        id_out = self.rbr_identity(input) if self.rbr_identity is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
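
`RepConvLayer` is the structural-reparameterization block from the FAST paper: at training time it sums a main k-by-k convolution with a vertical-strip branch, a horizontal-strip branch, and an identity BatchNorm branch (each padded so the spatial sizes line up), and at inference time the branches can later be folded into a single convolution. A quick shape check of the block as defined above (the module path is the one introduced here; it is renamed to `modules/layers` in the next patch):

    import torch

    from doctr.models.modules.TextNet.pytorch import RepConvLayer

    # Stride 1 and matching channel counts, so the identity branch is active too.
    layer = RepConvLayer(in_channels=64, out_channels=64, kernel_size=[3, 3], stride=1)
    out = layer(torch.rand((1, 64, 32, 32)))
    print(out.shape)  # torch.Size([1, 64, 32, 32])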

From 151e65f771e459c9529499f0d2012c526d31899f Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 12:25:43 +0200
Subject: [PATCH 03/44] correcting some syntax for textnet backbone

---
 doctr/models/classification/fast/pytorch.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/doctr/models/classification/fast/pytorch.py b/doctr/models/classification/fast/pytorch.py
index 6cf84fd936..b8317e54e1 100644
--- a/doctr/models/classification/fast/pytorch.py
+++ b/doctr/models/classification/fast/pytorch.py
@@ -13,10 +13,10 @@
 from ...utils import load_pretrained_params
 
 from doctr.models.utils.pytorch import conv_sequence_pt
+from doctr.models.modules.TextNet.pytorch import RepConvLayer
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
-
 default_cfgs: Dict[str, Dict[str, Any]] = {
     "textnet_tiny": {
         #"mean": (0.694, 0.695, 0.693),
@@ -76,7 +76,7 @@ def __init__(
 
         _layers.extend([self.first_conv])
         for stage in [stage1, stage2, stage3, stage4]:
-            stage_ = nn.ModuleList([RepConvLayer(in_channels, out_channels, kernel_size, stride) for in_channels,out_channels,kernel_size,stride in stage])
+            stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
             _layers.extend([stage_])
         
         if include_top:
@@ -93,10 +93,10 @@ def __init__(
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight)
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, nn.BatchNorm2d):
-                m.weight.data.fill_(1)
-                m.bias.data.zero_()
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
 
 
 
@@ -151,13 +151,13 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
         pretrained,
         TextNet,
         stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2,
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
                    {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
                    
         stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
                    {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
                    {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}],
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},],
                    
         stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
                    {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
@@ -280,7 +280,7 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
                    {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
                    {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
                    {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},]
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}],
         ignore_keys=["fc.weight", "fc.bias"],
         **kwargs,
     )

From 57d33f825366719577d95b31fd16a4d6b9de7e6e Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 13:30:40 +0200
Subject: [PATCH 04/44] changing import layer for textnet

---
 doctr/models/classification/fast/pytorch.py          | 2 +-
 doctr/models/modules/{TextNet => layers}/__init__.py | 0
 doctr/models/modules/{TextNet => layers}/pytorch.py  | 0
 3 files changed, 1 insertion(+), 1 deletion(-)
 rename doctr/models/modules/{TextNet => layers}/__init__.py (100%)
 rename doctr/models/modules/{TextNet => layers}/pytorch.py (100%)

diff --git a/doctr/models/classification/fast/pytorch.py b/doctr/models/classification/fast/pytorch.py
index b8317e54e1..a912eaf04f 100644
--- a/doctr/models/classification/fast/pytorch.py
+++ b/doctr/models/classification/fast/pytorch.py
@@ -13,7 +13,7 @@
 from ...utils import load_pretrained_params
 
 from doctr.models.utils.pytorch import conv_sequence_pt
-from doctr.models.modules.TextNet.pytorch import RepConvLayer
+from doctr.models.modules.layers.pytorch import RepConvLayer
 
 __all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
 
diff --git a/doctr/models/modules/TextNet/__init__.py b/doctr/models/modules/layers/__init__.py
similarity index 100%
rename from doctr/models/modules/TextNet/__init__.py
rename to doctr/models/modules/layers/__init__.py
diff --git a/doctr/models/modules/TextNet/pytorch.py b/doctr/models/modules/layers/pytorch.py
similarity index 100%
rename from doctr/models/modules/TextNet/pytorch.py
rename to doctr/models/modules/layers/pytorch.py
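
With the module now living under `doctr/models/modules/layers` and patch 03's `RepConvLayer(**params)` change in place, each stage config is a plain list of keyword-argument dictionaries that map one-to-one onto the constructor. A minimal sketch of the pattern, using the stage1 entries of `textnet_tiny`:

    import torch.nn as nn

    from doctr.models.modules.layers.pytorch import RepConvLayer

    stage1 = [
        {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
        {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
        {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
    ]
    # Each dict unpacks directly into RepConvLayer's keyword arguments.
    blocks = nn.ModuleList([RepConvLayer(**params) for params in stage1])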

From dae10b22f23182a9f30e7170a1012538c8631331 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 15:16:37 +0200
Subject: [PATCH 05/44] renaming textnet to textnetfast

---
 .../{fast => textnetFast}/__init__.py         |  0
 .../{fast => textnetFast}/pytorch.py          | 58 +++++++++----------
 .../{fast => textnetFast}/tensorflow.py       |  0
 3 files changed, 29 insertions(+), 29 deletions(-)
 rename doctr/models/classification/{fast => textnetFast}/__init__.py (100%)
 rename doctr/models/classification/{fast => textnetFast}/pytorch.py (90%)
 rename doctr/models/classification/{fast => textnetFast}/tensorflow.py (100%)

diff --git a/doctr/models/classification/fast/__init__.py b/doctr/models/classification/textnetFast/__init__.py
similarity index 100%
rename from doctr/models/classification/fast/__init__.py
rename to doctr/models/classification/textnetFast/__init__.py
diff --git a/doctr/models/classification/fast/pytorch.py b/doctr/models/classification/textnetFast/pytorch.py
similarity index 90%
rename from doctr/models/classification/fast/pytorch.py
rename to doctr/models/classification/textnetFast/pytorch.py
index a912eaf04f..2df4176d48 100644
--- a/doctr/models/classification/fast/pytorch.py
+++ b/doctr/models/classification/textnetFast/pytorch.py
@@ -15,24 +15,24 @@
 from doctr.models.utils.pytorch import conv_sequence_pt
 from doctr.models.modules.layers.pytorch import RepConvLayer
 
-__all__ = ["textnet_tiny", "textnet_small", "textnet_base"]
+__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
 
 default_cfgs: Dict[str, Dict[str, Any]] = {
-    "textnet_tiny": {
+    "textnetfast_tiny": {
         #"mean": (0.694, 0.695, 0.693),
         #"std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
         "url": None,
     },
-    "textnet_small": {
+    "textnetfast_small": {
         #"mean": (0.694, 0.695, 0.693),
         #"std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
         "classes": list(VOCABS["french"]),
         "url": None,
     },
-    "textnet_base": {
+    "textnetfast_base": {
         #"mean": (0.694, 0.695, 0.693),
         #"std": (0.299, 0.296, 0.301),
         "input_shape": (3, 32, 32),
@@ -42,7 +42,7 @@
 }
 
 
-class TextNet(nn.Module):
+class TextNetFast(nn.Module):
     """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
      <https://arxiv.org/abs/2111.02394>>`_.
 
@@ -69,7 +69,7 @@ def __init__(
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
     
-        super(TextNet, self).__init__()
+        super(TextNetFast, self).__init__()
         
         _layers: List[nn.Module] = []
         self.first_conv = nn.ModuleList(conv_sequence(in_channels, out_channels, True, True, kernel_size=kernel_size, stride=stride))
@@ -100,13 +100,13 @@ def __init__(
 
 
 
-def _textnet(
+def _textnetfast(
     arch: str,
     pretrained: bool,
     arch_fn,
     ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
-) -> TextNet:
+) -> TextNetFast:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
     kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
 
@@ -129,13 +129,13 @@ def _textnet(
     return model
 
 
-def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
     """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
     <https://arxiv.org/abs/2111.02394>`_.
 
     >>> import torch
-    >>> from doctr.models import textnet_tiny
-    >>> model = textnet_tiny(pretrained=False)
+    >>> from doctr.models import textnetfast_tiny
+    >>> model = textnetfast_tiny(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)
 
@@ -146,10 +146,10 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
         A TextNet model
     """
 
-    return _textnet(
-        "textnet_tiny",
+    return _textnetfast(
+        "textnetfast_tiny",
         pretrained,
-        TextNet,
+        TextNetFast,
         stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
                    {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
                    {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
@@ -173,13 +173,13 @@ def textnet_tiny(pretrained: bool = False, **kwargs: Any) -> TextNet:
         **kwargs,
     )
     
-def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
-    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """TextNetFast architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
     <https://arxiv.org/abs/2111.02394>`_.
 
     >>> import torch
-    >>> from doctr.models import textnet_small
-    >>> model = textnet_small(pretrained=False)
+    >>> from doctr.models import textnetfast_small
+    >>> model = textnetfast_small(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)
 
@@ -187,13 +187,13 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
         pretrained: boolean, True if model is pretrained
 
     Returns:
-        A TextNet model
+        A TextNetFast model
     """
 
-    return _textnet(
-        "textnet_small",
+    return _textnetfast(
+        "textnetfast_small",
         pretrained,
-        TextNet,
+        TextNetFast,
         stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
                    {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}],
                    
@@ -224,13 +224,13 @@ def textnet_small(pretrained: bool = False, **kwargs: Any) -> TextNet:
         **kwargs,
     )
     
-def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
+def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
     """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
     <https://arxiv.org/abs/2111.02394>`_.
 
     >>> import torch
-    >>> from doctr.models import textnet_base
-    >>> model = textnet_base(pretrained=False)
+    >>> from doctr.models import textnetfast_base
+    >>> model = textnetfast_base(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)
 
@@ -238,13 +238,13 @@ def textnet_base(pretrained: bool = False, **kwargs: Any) -> TextNet:
         pretrained: boolean, True if model is pretrained
 
     Returns:
-        A TextNet model
+        A TextNetFast model
     """
 
-    return _textnet(
-        "textnet_base",
+    return _textnetfast(
+        "textnetfast_base",
         pretrained,
-        TextNet,
+        TextNetFast,
         stage1 = [ {"kernel_size": [3, 3], "stride": 1},
                    {"kernel_size": [3, 3], "stride": 2},
                    {"kernel_size": [3, 1], "stride": 1},
diff --git a/doctr/models/classification/fast/tensorflow.py b/doctr/models/classification/textnetFast/tensorflow.py
similarity index 100%
rename from doctr/models/classification/fast/tensorflow.py
rename to doctr/models/classification/textnetFast/tensorflow.py
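Each stage passed to the constructor above is a plain list of dicts holding RepConvLayer constructor arguments. A minimal, self-contained sketch of how one such config expands into modules (StubRepConv is a hypothetical stand-in for doctr's RepConvLayer, used only to illustrate the expansion pattern):

    import torch.nn as nn

    # Hypothetical stand-in for doctr's RepConvLayer, just to show the expansion
    class StubRepConv(nn.Module):
        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super().__init__()
            padding = (kernel_size[0] // 2, kernel_size[1] // 2)
            self.conv = nn.Conv2d(in_channels, out_channels, tuple(kernel_size), stride, padding, bias=False)

        def forward(self, x):
            return self.conv(x)

    stage1 = [
        {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
        {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
    ]
    # One module per dict, in order: the pattern the backbone applies per stage
    blocks = nn.Sequential(*[StubRepConv(**params) for params in stage1])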

From 2465696f0fe03feeff5966e09cb9531e059d001a Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 15:29:59 +0200
Subject: [PATCH 06/44] add textnetfast tests and run them successfully

---
 tests/pytorch/test_models_classification_pt.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py
index 0ea879097a..5f2f7cc2ed 100644
--- a/tests/pytorch/test_models_classification_pt.py
+++ b/tests/pytorch/test_models_classification_pt.py
@@ -44,6 +44,9 @@ def _test_classification(model, input_shape, output_size, batch_size=2):
         ["vit_b", (3, 32, 32), (126,)],
         # Check that the interpolation of positional embeddings for vit models works correctly
         ["vit_s", (3, 64, 64), (126,)],
+        ["textnetfast_tiny", (3, 32, 32), (126,)],
+        ["textnetfast_small", (3, 32, 32), (126,)],
+        ["textnetfast_base", (3, 32, 32), (126,)],
     ],
 )
 def test_classification_architectures(arch_name, input_shape, output_size):
@@ -125,6 +128,9 @@ def test_crop_orientation_model(mock_text_box):
         ["mobilenet_v3_large", (3, 32, 32), (126,)],
         ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
         ["vit_b", (3, 32, 32), (126,)],
+        ["textnetfast_tiny", (3, 32, 32), (126,)],
+        ["textnetfast_small", (3, 32, 32), (126,)],
+        ["textnetfast_base", (3, 32, 32), (126,)],
     ],
 )
 def test_models_onnx_export(arch_name, input_shape, output_size):
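For context, a standalone sketch of the kind of parametrized shape check these entries feed into (`_test_classification` itself lives in the doctr test suite; the dummy model below is purely illustrative):

    import pytest
    import torch

    @pytest.mark.parametrize(
        "input_shape, output_size",
        [((3, 32, 32), (126,))],
    )
    def test_dummy_classifier_shape(input_shape, output_size):
        # Stand-in for a real backbone: flatten + lazily-shaped linear head
        model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(output_size[0]))
        out = model(torch.rand((2, *input_shape)))
        assert out.shape == (2, *output_size)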

From ab5b5b3af686c3094d239e3d88b8b159ea6b4145 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 17:01:17 +0200
Subject: [PATCH 07/44] fix several issues to get TextNetFast running: not working yet

---
 doctr/models/classification/__init__.py       |  1 +
 .../classification/textnetFast/pytorch.py     | 35 ++++++++++---------
 doctr/models/classification/zoo.py            |  3 ++
 doctr/models/modules/layers/pytorch.py        |  1 +
 4 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py
index 72e68b78df..c65a34abaf 100644
--- a/doctr/models/classification/__init__.py
+++ b/doctr/models/classification/__init__.py
@@ -4,3 +4,4 @@
 from .magc_resnet import *
 from .vit import *
 from .zoo import *
+from .textnetFast import *
diff --git a/doctr/models/classification/textnetFast/pytorch.py b/doctr/models/classification/textnetFast/pytorch.py
index 2df4176d48..2a3a22ede8 100644
--- a/doctr/models/classification/textnetFast/pytorch.py
+++ b/doctr/models/classification/textnetFast/pytorch.py
@@ -12,7 +12,7 @@
 
 from ...utils import 	load_pretrained_params
 
-from doctr.models.utils.pytorch import conv_sequence_pt
+from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
 from doctr.models.modules.layers.pytorch import RepConvLayer
 
 __all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
@@ -42,7 +42,7 @@
 }
 
 
-class TextNetFast(nn.Module):
+class TextNetFast(nn.Sequential):
     """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
      <https://arxiv.org/abs/2111.02394>>`_.
 
@@ -60,10 +60,10 @@ class TextNetFast(nn.Module):
 
     def __init__(
         self,
-        stage1: Dict[Any],
-        stage2: Dict[Any],
-        stage3: Dict[Any],
-        stage4: Dict[Any],
+        stage1: Dict,
+        stage2: Dict,
+        stage3: Dict,
+        stage4: Dict,
         include_top: bool = True,
         num_classes: int = 1000,
         cfg: Optional[Dict[str, Any]] = None,
@@ -71,23 +71,24 @@ def __init__(
     
         super(TextNetFast, self).__init__()
         
-        _layers: List[nn.Module]        
-        self.first_conv = nn.ModuleList[conv_sequence(in_channels, out_channels, True, True, kernel_size=kernel_size, stride=stride)]
-
-        _layers.extend([self.first_conv ])
+        _layers: List[nn.Module]    
+        self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
+        
+        _layers = [nn.ModuleList([*self.first_conv])]
+        
         for stage in [stage1, stage2, stage3, stage4]:
 	        stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
-	        _layers.extend([stage_])
+	        _layers.extend([*stage_])
         
         if include_top:
             _layers.extend(
                 [
                     nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(1),
-                    nn.Linear(output_channels[-1], num_classes, bias=True),
+                    nn.Linear(64, num_classes, bias=True),
                 ]
             )
-
+        
         super().__init__(*_layers)
         self.cfg = cfg
 
@@ -97,8 +98,8 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-
-
+        print(self) 
+       
 
 def _textnetfast(
     arch: str,
@@ -129,7 +130,7 @@ def _textnetfast(
     return model
 
 
-def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TVResNet:
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
     """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
     <https://arxiv.org/abs/2111.02394>`_.
 
@@ -173,7 +174,7 @@ def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TVResNet:
         **kwargs,
     )
     
-def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TVResNet:
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
     """TextNetFast architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
     <https://arxiv.org/abs/2111.02394>`_.
 
diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py
index 9ec80a2619..3ae0273356 100644
--- a/doctr/models/classification/zoo.py
+++ b/doctr/models/classification/zoo.py
@@ -27,6 +27,9 @@
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
+    "textnet_tiny",
+    "textnet_small",
+    "textnet_base",
 ]
 ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
 
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 5ed5eb3263..97ae47a99e 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -42,6 +42,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride):
 
 
     def forward(self, input):
+    
         main_outputs = self.main_bn(self.main_conv(input))
         vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
         horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
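The forward pass above sums up to four parallel branches before the ReLU. A minimal sketch (not the doctr code) of that multi-branch sum for a 3x3 kernel at stride 1, where every branch preserves the spatial size:

    import torch
    import torch.nn as nn

    x = torch.rand(1, 8, 16, 16)
    main = nn.Conv2d(8, 8, (3, 3), padding=(1, 1), bias=False)   # full kernel
    ver = nn.Conv2d(8, 8, (3, 1), padding=(1, 0), bias=False)    # vertical strip
    hor = nn.Conv2d(8, 8, (1, 3), padding=(0, 1), bias=False)    # horizontal strip
    identity = nn.BatchNorm2d(8)  # only valid with matching channels and stride 1

    out = torch.relu(main(x) + ver(x) + hor(x) + identity(x))
    assert out.shape == x.shape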

From 564f9ebc5966578573570f257798ded1b89a3f80 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 18:15:25 +0200
Subject: [PATCH 08/44] run the PyTorch classification tests successfully

---
 .../classification/textnetFast/pytorch.py     | 30 ++++++++--------
 doctr/models/classification/zoo.py            |  6 ++--
 doctr/models/modules/layers/pytorch.py        | 35 +++++++++++++------
 3 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/doctr/models/classification/textnetFast/pytorch.py b/doctr/models/classification/textnetFast/pytorch.py
index 2a3a22ede8..078b029084 100644
--- a/doctr/models/classification/textnetFast/pytorch.py
+++ b/doctr/models/classification/textnetFast/pytorch.py
@@ -57,24 +57,23 @@ class TextNetFast(nn.Sequential):
         include_top: whether the classifier head should be instantiated
         num_classes: number of output classes
     """
-
     def __init__(
         self,
         stage1: Dict,
         stage2: Dict,
         stage3: Dict,
         stage4: Dict,
+        #input_shape: Tuple[int, int, int] = (32, 32, 3),        
         include_top: bool = True,
         num_classes: int = 1000,
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
     
-        super(TextNetFast, self).__init__()
         
         _layers: List[nn.Module]    
         self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
         
-        _layers = [nn.ModuleList([*self.first_conv])]
+        _layers = [*self.first_conv]
         
         for stage in [stage1, stage2, stage3, stage4]:
 	        stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
@@ -85,21 +84,20 @@ def __init__(
                 [
                     nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(1),
-                    nn.Linear(64, num_classes, bias=True),
+                    nn.Linear(512, num_classes, bias=True),
                 ]
             )
         
         super().__init__(*_layers)
         self.cfg = cfg
 
+
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-        print(self) 
-       
 
 def _textnetfast(
     arch: str,
@@ -246,16 +244,16 @@ def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
         "textnetfast_base",
         pretrained,
         TextNetFast,
-        stage1 = [ {"kernel_size": [3, 3], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 2},
-                   {"kernel_size": [3, 1], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 1},
-                   {"kernel_size": [3, 1], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 1},
-                   {"kernel_size": [1, 3], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 1},
-                   {"kernel_size": [3, 3], "stride": 1}],
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
 
         stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
                    {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py
index 3ae0273356..c573373d71 100644
--- a/doctr/models/classification/zoo.py
+++ b/doctr/models/classification/zoo.py
@@ -27,9 +27,9 @@
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
-    "textnet_tiny",
-    "textnet_small",
-    "textnet_base",
+    "textnetfast_tiny",
+    "textnetfast_small",
+    "textnetfast_base",
 ]
 ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
 
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 97ae47a99e..62772e97a5 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,45 +1,60 @@
 import torch.nn as nn
 
 class RepConvLayer(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride):
+    def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1, groups=1):
         super(RepConvLayer, self).__init__()
-        self.ver_conv, self.ver_bn = None, None
-        self.hor_conv, self.hor_bn = None, None
-
+        
+        padding = (int(((kernel_size[0] - 1) * dilation) / 2),
+                   int(((kernel_size[1] - 1) * dilation) / 2))
+        
         self.activation = nn.ReLU(inplace=True)
         self.main_conv = nn.Conv2d(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
             stride=stride,
-            padding=kernel_size[0] // 2,
+            padding=padding,
+            dilation=dilation, 
+            groups=groups,
             bias=False,
         )
         self.main_bn = nn.BatchNorm2d(out_channels)
 
+        ver_pad = (int(((kernel_size[0] - 1) * dilation) / 2), 0)
+        hor_pad = (0, int(((kernel_size[1] - 1) * dilation) / 2))
+        
+        
         if kernel_size[1] != 1:
             self.ver_conv = nn.Conv2d(
                 in_channels,
                 out_channels,
                 kernel_size=(kernel_size[0], 1),
                 stride=stride,
-                padding=(kernel_size[0] // 2, 0),
+                padding=ver_pad,
+                dilation=dilation, 
+                groups=groups,
                 bias=False
             )
             self.ver_bn = nn.BatchNorm2d(out_channels)
-
+        else:
+             self.ver_conv, self.ver_bn = None, None
+             
         if kernel_size[0] != 1:
             self.hor_conv = nn.Conv2d(
                 in_channels,
                 out_channels,
                 kernel_size=(1, kernel_size[1]),
                 stride=stride,
-                padding=(0, kernel_size[1] // 2),
+                padding=hor_pad,
+                dilation=dilation, 
+                groups=groups,
                 bias=False
             )
             self.hor_bn = nn.BatchNorm2d(out_channels)
-        self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels else None
-
+        else:
+            self.hor_conv, self.hor_bn = None, None
+        
+        self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
 
     def forward(self, input):
     
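The new padding rule p = ((k - 1) * d) // 2 generalizes the earlier k // 2 to dilated and asymmetric kernels. A quick standalone check (a sketch, not part of the patch) that it preserves spatial size at stride 1:

    import torch
    import torch.nn as nn

    for k, d in [((3, 3), 1), ((3, 1), 1), ((1, 3), 2)]:
        pad = (((k[0] - 1) * d) // 2, ((k[1] - 1) * d) // 2)
        conv = nn.Conv2d(4, 4, kernel_size=k, stride=1, padding=pad, dilation=d, bias=False)
        out = conv(torch.rand(1, 4, 32, 32))
        assert out.shape[-2:] == (32, 32), (k, d, out.shape)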

From 7f48b06c0a7e7412af80ef3d24acffc901e91f21 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 23:07:01 +0200
Subject: [PATCH 09/44] add a first version of the textnetfast model in TensorFlow

---
 .../classification/textnetFast/tensorflow.py  | 276 ++++++++++++++++++
 doctr/models/modules/layers/tensorflow.py     |  72 +++++
 2 files changed, 348 insertions(+)
 create mode 100644 doctr/models/modules/layers/tensorflow.py

diff --git a/doctr/models/classification/textnetFast/tensorflow.py b/doctr/models/classification/textnetFast/tensorflow.py
index e69de29bb2..97fbab0f76 100644
--- a/doctr/models/classification/textnetFast/tensorflow.py
+++ b/doctr/models/classification/textnetFast/tensorflow.py
@@ -0,0 +1,276 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+from tensorflow.keras.applications import ResNet50
+from tensorflow.keras.models import Sequential
+
+from doctr.datasets import VOCABS
+
+from ...utils import conv_sequence, load_pretrained_params
+
+from doctr.models.modules.layers.tensorflow import RepConvLayer
+
+__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnetfast_tiny": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_small": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_base": {
+        #"mean": (0.694, 0.695, 0.693),
+        #"std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+}
+
+
+class TextNetFast(Sequential):
+    """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
+     <https://arxiv.org/abs/2111.02394>>`_.
+
+    Args:
+        num_blocks: number of resnet block in each stage
+        output_channels: number of channels in each stage
+        stage_conv: whether to add a conv_sequence after each stage
+        stage_pooling: pooling to add after each stage (if None, no pooling)
+        origin_stem: whether to use the orginal ResNet stem or ResNet-31's
+        stem_channels: number of output channels of the stem convolutions
+        attn_module: attention module to use in each stage
+        include_top: whether the classifier head should be instantiated
+        num_classes: number of output classes
+    """
+    def __init__(
+        self,
+        stage1: Dict,
+        stage2: Dict,
+        stage3: Dict,
+        stage4: Dict,
+        #input_shape: Tuple[int, int, int] = (32, 32, 3),        
+        include_top: bool = True,
+        num_classes: int = 1000,
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+    
+        
+        _layers = [*conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)]
+        
+        for stage in [stage1, stage2, stage3, stage4]:
+	        _layers.extend([RepConvLayer(**params) for params in stage])
+        
+        if include_top:
+            _layers.extend([
+                layers.GlobalAveragePooling2D(),
+                layers.Flatten(),
+                layers.Dense(num_classes, activation=None)
+            ])
+        
+        super().__init__(_layers)
+        self.cfg = cfg
+
+
+def _textnetfast(
+    arch: str,
+    pretrained: bool,
+    arch_fn,
+    ignore_keys: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> TextNetFast:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = arch_fn(**kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    model.cfg = _cfg
+
+    return model
+
+
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_tiny
+    >>> model = textnetfast_tiny(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_tiny",
+        pretrained,
+        TextNetFast,
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
+                   
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},],
+                   
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},],
+                   
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1}],
+                   
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
+    
+
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """TextNetFast architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_small
+    >>> model = textnetfast_small(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_small",
+        pretrained,
+        TextNetFast,
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}],
+                   
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
+                   
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},],
+                   
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},],
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
+
+
+def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
+    <https://arxiv.org/abs/2111.02394>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_base
+    >>> model = textnetfast_base(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_base",
+        pretrained,
+        TextNetFast,
+        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
+
+        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
+
+        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},],
+                   
+        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}],
+        ignore_keys=["fc.weight", "fc.bias"],
+        **kwargs,
+    )
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
new file mode 100644
index 0000000000..eca30f75ec
--- /dev/null
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -0,0 +1,72 @@
+import math
+from typing import Any, Tuple
+
+import tensorflow as tf
+from tensorflow.keras import layers
+
+
+class RepConvLayer(layers.Layer):
+    def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1, groups=1):
+        super(RepConvLayer, self).__init__()
+        
+        padding = (int(((kernel_size[0] - 1) * dilation) / 2),
+                   int(((kernel_size[1] - 1) * dilation) / 2))
+        
+        self.activation = layers.ReLU()
+        self.main_conv = layers.Conv2D(
+            filters=out_channels,
+            kernel_size=kernel_size,
+            strides=stride,
+            padding=padding,
+            dilation_rate=dilation,
+            groups=groups,
+            use_bias=False,
+            input_shape=(None, None, in_channels) 
+        )
+
+        self.main_bn = layers.BatchNormalization()
+
+        ver_pad = (int(((kernel_size[0] - 1) * dilation) / 2), 0)
+        hor_pad = (0, int(((kernel_size[1] - 1) * dilation) / 2))
+        
+        
+        if kernel_size[1] != 1:
+            self.ver_conv = layers.Conv2D(
+            filters=out_channels,
+            kernel_size=(kernel_size[0], 1),
+            strides=(stride, 1),
+            padding='valid',  
+            dilation_rate=(dilation, 1),
+            groups=groups,
+            use_bias=False,
+            input_shape=(None, None, in_channels) 
+            )
+            self.ver_bn = layers.BatchNormalization()
+        else:
+             self.ver_conv, self.ver_bn = None, None
+             
+        if kernel_size[0] != 1:
+            self.hor_conv = layers.Conv2D(
+                filters=out_channels,
+                kernel_size=(1, kernel_size[1]),
+                strides=stride,
+                padding='valid',  # Keras Conv2D only accepts 'valid'/'same'; PyTorch-style explicit padding is not supported
+                dilation_rate=dilation,
+                groups=groups,
+                use_bias=False,
+                input_shape=(None, None, in_channels)  # specify the input shape here
+                )
+            self.hor_bn = layers.BatchNormalization()
+        else:
+            self.hor_conv, self.hor_bn = None, None
+        
+        self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+
+    def forward(self, input):
+    
+        main_outputs = self.main_bn(self.main_conv(input))
+        vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
+        horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
+        id_out = self.rbr_identity(input) if self.rbr_identity is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
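Two caveats with the layer above: Keras dispatches through `call`, not `forward`, so the method defined here is never invoked when the layer is applied, and `Conv2D` only accepts 'valid' or 'same' for `padding`, not a tuple. A minimal sketch of the Keras idiom for a summed two-branch layer (a hypothetical toy layer, not the doctr implementation):

    import tensorflow as tf
    from tensorflow.keras import layers

    class BranchSum(layers.Layer):
        def __init__(self, filters):
            super().__init__()
            self.main = layers.Conv2D(filters, (3, 3), padding="same", use_bias=False)
            self.side = layers.Conv2D(filters, (3, 1), padding="same", use_bias=False)
            self.act = layers.ReLU()

        def call(self, x):  # Keras hook, the equivalent of PyTorch's forward
            return self.act(self.main(x) + self.side(x))

    out = BranchSum(8)(tf.random.uniform((1, 16, 16, 3)))
    assert out.shape == (1, 16, 16, 8)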

From 346bcd4f5381979a5db20992eb6fdc9441c0caa6 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 29 Aug 2023 23:45:46 +0200
Subject: [PATCH 10/44] finish implementing the TensorFlow textnet classification
 model => move on to TF tests

---
 .../classification/textnetFast/tensorflow.py  | 13 ++++++------
 doctr/models/utils/tensorflow.py              |  2 +-
 .../test_models_classification_tf.py          | 21 +++++++++++++++++++
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/doctr/models/classification/textnetFast/tensorflow.py b/doctr/models/classification/textnetFast/tensorflow.py
index 97fbab0f76..8e735abc62 100644
--- a/doctr/models/classification/textnetFast/tensorflow.py
+++ b/doctr/models/classification/textnetFast/tensorflow.py
@@ -14,7 +14,7 @@
 from doctr.datasets import VOCABS
 
 from ...utils import conv_sequence, load_pretrained_params
-
+from doctr.models.utils.tensorflow import conv_sequence
 from doctr.models.modules.layers.tensorflow import RepConvLayer
 
 __all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
@@ -64,18 +64,19 @@ def __init__(
         stage1: Dict,
         stage2: Dict,
         stage3: Dict,
-        stage4: Dict,
-        #input_shape: Tuple[int, int, int] = (32, 32, 3),        
+        stage4: Dict,     
         include_top: bool = True,
         num_classes: int = 1000,
+        input_shape: Tuple[int, int, int] = (32, 32, 3),
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
     
         
-        _layers = [*conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)]
+        _layers = [*conv_sequence(input_shape=input_shape, out_channels=64, activation='relu', bn=True, kernel_size=3, strides=2)]
         
         for stage in [stage1, stage2, stage3, stage4]:
-	        _layers.extend([RepConvLayer(**params) for params in stage])
+	        stage_ = Sequential([RepConvLayer(**params) for params in stage])
+	        _layers.extend([stage_])
         
         if include_top:
             _layers.extend([
@@ -83,7 +84,7 @@ def __init__(
                 layers.Flatten(),
                 layers.Dense(num_classes, activation=None)
             ])
-        
+            
         super().__init__(_layers)
         self.cfg = cfg
 
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
index 8490c09f11..346b320259 100644
--- a/doctr/models/utils/tensorflow.py
+++ b/doctr/models/utils/tensorflow.py
@@ -66,7 +66,7 @@ def conv_sequence(
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
-    **kwargs: Any,
+    **kwargs: Any,    
 ) -> List[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py
index 25d3ca5ad0..84eefb9f9d 100644
--- a/tests/tensorflow/test_models_classification_tf.py
+++ b/tests/tensorflow/test_models_classification_tf.py
@@ -29,6 +29,9 @@
         ["mobilenet_v3_large", (32, 32, 3), (126,)],
         ["vit_s", (32, 32, 3), (126,)],
         ["vit_b", (32, 32, 3), (126,)],
+        ["textnetfast_tiny", (32, 32, 3), (126,)],
+        ["textnetfast_small", (32, 32, 3), (126,)],
+        ["textnetfast_base", (32, 32, 3), (126,)],
     ],
 )
 def test_classification_architectures(arch_name, input_shape, output_size):
@@ -136,6 +139,24 @@ def test_crop_orientation_model(mock_text_box):
             (126,),
             marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
         ),
+        pytest.param(
+            "textnetfast_tiny",
+            (32, 32, 3),
+            (126,),
+            marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
+        ),
+        pytest.param(
+            "textnetfast_small",
+            (32, 32, 3),
+            (126,),
+            marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
+        ),
+        pytest.param(
+            "textnetfast_base",
+            (32, 32, 3),
+            (126,),
+            marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
+        ),
     ],
 )
 def test_models_onnx_export(arch_name, input_shape, output_size):
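The export test above round-trips each architecture through ONNX. A minimal sketch of the underlying conversion on a toy Keras model (assuming tf2onnx's from_keras entry point; the actual helper lives in the doctr test utilities):

    import tensorflow as tf
    import tf2onnx

    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(32, 32, 3)),
        tf.keras.layers.Conv2D(8, 3),
        tf.keras.layers.ReLU(),
    ])
    spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
    # Convert the toy Keras graph to ONNX, analogous to what the test does per backbone
    model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, output_path="model.onnx")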

From eb9299eb34e21cd9ffcdc2a21f24a175f89b8158 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 13:08:05 +0200
Subject: [PATCH 11/44] some changes + make style + make quality

---
 doctr/models/classification/__init__.py       |   2 +-
 .../classification/textnetFast/pytorch.py     | 285 -----------------
 .../classification/textnetFast/tensorflow.py  | 277 ----------------
 .../{textnetFast => textnet_fast}/__init__.py |   0
 .../classification/textnet_fast/pytorch.py    | 297 ++++++++++++++++++
 .../classification/textnet_fast/tensorflow.py | 288 +++++++++++++++++
 doctr/models/modules/layers/pytorch.py        |  30 +-
 doctr/models/modules/layers/tensorflow.py     |  51 ++-
 doctr/models/utils/tensorflow.py              |   2 +-
 .../pytorch/test_models_classification_pt.py  |  12 +-
 .../test_models_classification_tf.py          |  33 +-
 11 files changed, 644 insertions(+), 633 deletions(-)
 delete mode 100644 doctr/models/classification/textnetFast/pytorch.py
 delete mode 100644 doctr/models/classification/textnetFast/tensorflow.py
 rename doctr/models/classification/{textnetFast => textnet_fast}/__init__.py (100%)
 create mode 100644 doctr/models/classification/textnet_fast/pytorch.py
 create mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py
index c65a34abaf..e1b303ef2f 100644
--- a/doctr/models/classification/__init__.py
+++ b/doctr/models/classification/__init__.py
@@ -4,4 +4,4 @@
 from .magc_resnet import *
 from .vit import *
 from .zoo import *
-from .textnetFast import *
+from .textnet_fast import *
diff --git a/doctr/models/classification/textnetFast/pytorch.py b/doctr/models/classification/textnetFast/pytorch.py
deleted file mode 100644
index 078b029084..0000000000
--- a/doctr/models/classification/textnetFast/pytorch.py
+++ /dev/null
@@ -1,285 +0,0 @@
-# Copyright (C) 2021-2023, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
-import torch.nn as nn
-
-from doctr.datasets import VOCABS
-
-from ...utils import 	load_pretrained_params
-
-from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.modules.layers.pytorch import RepConvLayer
-
-__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
-
-default_cfgs: Dict[str, Dict[str, Any]] = {
-    "textnetfast_tiny": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_small": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_base": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-}
-
-
-class TextNetFast(nn.Sequential):
-    """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
-     <https://arxiv.org/abs/2111.02394>>`_.
-
-    Args:
-        num_blocks: number of resnet block in each stage
-        output_channels: number of channels in each stage
-        stage_conv: whether to add a conv_sequence after each stage
-        stage_pooling: pooling to add after each stage (if None, no pooling)
-        origin_stem: whether to use the orginal ResNet stem or ResNet-31's
-        stem_channels: number of output channels of the stem convolutions
-        attn_module: attention module to use in each stage
-        include_top: whether the classifier head should be instantiated
-        num_classes: number of output classes
-    """
-    def __init__(
-        self,
-        stage1: Dict,
-        stage2: Dict,
-        stage3: Dict,
-        stage4: Dict,
-        #input_shape: Tuple[int, int, int] = (32, 32, 3),        
-        include_top: bool = True,
-        num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-    ) -> None:
-    
-        
-        _layers: List[nn.Module]    
-        self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
-        
-        _layers = [*self.first_conv]
-        
-        for stage in [stage1, stage2, stage3, stage4]:
-	        stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
-	        _layers.extend([*stage_])
-        
-        if include_top:
-            _layers.extend(
-                [
-                    nn.AdaptiveAvgPool2d(1),
-                    nn.Flatten(1),
-                    nn.Linear(512, num_classes, bias=True),
-                ]
-            )
-        
-        super().__init__(*_layers)
-        self.cfg = cfg
-
-
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, nn.BatchNorm2d):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-def _textnetfast(
-    arch: str,
-    pretrained: bool,
-    arch_fn,
-    ignore_keys: Optional[List[str]] = None,
-    **kwargs: Any,
-) -> TextNetFast:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["classes"] = kwargs["classes"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = arch_fn(**kwargs)
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # remove the last layer weights
-        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
-
-    model.cfg = _cfg
-
-    return model
-
-
-def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_tiny
-    >>> model = textnetfast_tiny(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNet model
-    """
-
-    return _textnetfast(
-        "textnetfast_tiny",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
-                   
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},],
-                   
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1}],
-                   
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
-    
-def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNetFast architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_small
-    >>> model = textnetfast_small(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_small",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}],
-                   
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
-                   
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},],
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
-    
-def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_base
-    >>> model = textnetfast_base(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_base",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
-
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
-
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}],
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
diff --git a/doctr/models/classification/textnetFast/tensorflow.py b/doctr/models/classification/textnetFast/tensorflow.py
deleted file mode 100644
index 8e735abc62..0000000000
--- a/doctr/models/classification/textnetFast/tensorflow.py
+++ /dev/null
@@ -1,277 +0,0 @@
-# Copyright (C) 2021-2023, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any, Callable, Dict, List, Optional, Tuple
-
-import tensorflow as tf
-from tensorflow.keras import layers
-from tensorflow.keras.applications import ResNet50
-from tensorflow.keras.models import Sequential
-
-from doctr.datasets import VOCABS
-
-from ...utils import conv_sequence, load_pretrained_params
-from doctr.models.utils.tensorflow import conv_sequence
-from doctr.models.modules.layers.tensorflow import RepConvLayer
-
-__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
-
-default_cfgs: Dict[str, Dict[str, Any]] = {
-    "textnetfast_tiny": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_small": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_base": {
-        #"mean": (0.694, 0.695, 0.693),
-        #"std": (0.299, 0.296, 0.301),
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-}
-
-
-class TextNetFast(Sequential):
-    """Implements a TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
-     <https://arxiv.org/abs/2111.02394>>`_.
-
-    Args:
-        num_blocks: number of resnet block in each stage
-        output_channels: number of channels in each stage
-        stage_conv: whether to add a conv_sequence after each stage
-        stage_pooling: pooling to add after each stage (if None, no pooling)
-        origin_stem: whether to use the orginal ResNet stem or ResNet-31's
-        stem_channels: number of output channels of the stem convolutions
-        attn_module: attention module to use in each stage
-        include_top: whether the classifier head should be instantiated
-        num_classes: number of output classes
-    """
-    def __init__(
-        self,
-        stage1: Dict,
-        stage2: Dict,
-        stage3: Dict,
-        stage4: Dict,     
-        include_top: bool = True,
-        num_classes: int = 1000,
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
-        cfg: Optional[Dict[str, Any]] = None,
-    ) -> None:
-    
-        
-        _layers = [*conv_sequence(input_shape=input_shape, out_channels=64, activation='relu', bn=True, kernel_size=3, strides=2)]
-        
-        for stage in [stage1, stage2, stage3, stage4]:
-	        stage_ = Sequential([RepConvLayer(**params) for params in stage])
-	        _layers.extend([stage_])
-        
-        if include_top:
-            _layers.extend([
-                layers.GlobalAveragePooling2D(),
-                layers.Flatten(),
-                layers.Dense(num_classes, activation=None)
-            ])
-            
-        super().__init__(_layers)
-        self.cfg = cfg
-
-
-def _textnetfast(
-    arch: str,
-    pretrained: bool,
-    arch_fn,
-    ignore_keys: Optional[List[str]] = None,
-    **kwargs: Any,
-) -> TextNetFast:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["classes"] = kwargs["classes"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = arch_fn(**kwargs)
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # remove the last layer weights
-        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
-
-    model.cfg = _cfg
-
-    return model
-
-
-def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_tiny
-    >>> model = textnetfast_tiny(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNet model
-    """
-
-    return _textnetfast(
-        "textnetfast_tiny",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
-                   
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},],
-                   
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1}],
-                   
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
-    
-
-def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNetFast architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_small
-    >>> model = textnetfast_small(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_small",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}],
-                   
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
-                   
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},],
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
-
-
-def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """TextNet architecture as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation",
-    <https://arxiv.org/abs/2111.02394>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_base
-    >>> model = textnetfast_base(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_base",
-        pretrained,
-        TextNetFast,
-        stage1 = [ {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}],
-
-        stage2 = [ {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},],
-
-        stage3 = [ {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-                   {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},],
-                   
-        stage4 = [ {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-                   {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}],
-        ignore_keys=["fc.weight", "fc.bias"],
-        **kwargs,
-    )
diff --git a/doctr/models/classification/textnetFast/__init__.py b/doctr/models/classification/textnet_fast/__init__.py
similarity index 100%
rename from doctr/models/classification/textnetFast/__init__.py
rename to doctr/models/classification/textnet_fast/__init__.py
diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
new file mode 100644
index 0000000000..2a262b5598
--- /dev/null
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -0,0 +1,295 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional
+
+import torch.nn as nn
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.layers.pytorch import RepConvLayer
+from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
+
+from ...utils import load_pretrained_params
+
+__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnetfast_tiny": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_small": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_base": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+}
+
+
+class TextNetFast(nn.Sequential):
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    Args:
+        num_blocks: number of ResNet blocks in each stage
+        output_channels: number of channels in each stage
+        stage_conv: whether to add a conv_sequence after each stage
+        stage_pooling: pooling to add after each stage (if None, no pooling)
+        origin_stem: whether to use the original ResNet stem or ResNet-31's
+        stem_channels: number of output channels of the stem convolutions
+        attn_module: attention module to use in each stage
+        include_top: whether the classifier head should be instantiated
+        num_classes: number of output classes
+    """
+
+    def __init__(
+        self,
+        stage1: Dict,
+        stage2: Dict,
+        stage3: Dict,
+        stage4: Dict,
+        include_top: bool = True,
+        num_classes: int = 1000,
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        _layers: List[nn.Module] = [
+            *conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
+        ]
+
+        for stage in [stage1, stage2, stage3, stage4]:
+            _layers.extend([RepConvLayer(**params) for params in stage])
+
+        if include_top:
+            _layers.extend(
+                [
+                    nn.AdaptiveAvgPool2d(1),
+                    nn.Flatten(1),
+                    nn.Linear(512, num_classes, bias=True),
+                ]
+            )
+
+        super().__init__(*_layers)
+        self.cfg = cfg
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+def _textnetfast(
+    arch: str,
+    pretrained: bool,
+    arch_fn,
+    ignore_keys: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> TextNetFast:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = arch_fn(**kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    model.cfg = _cfg
+
+    return model
+
+
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnetfast_tiny
+    >>> model = textnetfast_tiny(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_tiny",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
+
+
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnetfast_small
+    >>> model = textnetfast_small(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_small",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
+
+
+def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import textnetfast_base
+    >>> model = textnetfast_base(pretrained=False)
+    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_base",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
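A quick sanity check of the flattened `nn.Sequential` layout above (a minimal sketch, assuming `textnetfast_tiny` is re-exported from `doctr.models` as the test files in this series suggest):

```python
import torch

from doctr.models import textnetfast_tiny

# num_classes is forwarded by _textnetfast down to the final nn.Linear
model = textnetfast_tiny(pretrained=False, num_classes=10)
logits = model(torch.rand(1, 3, 32, 32))
print(logits.shape)  # torch.Size([1, 10])
```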
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
new file mode 100644
index 0000000000..447fb0854e
--- /dev/null
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -0,0 +1,288 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+from tensorflow.keras import layers
+from tensorflow.keras.models import Sequential
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.layers.tensorflow import RepConvLayer
+from doctr.models.utils.tensorflow import conv_sequence
+
+from ...utils import load_pretrained_params
+
+__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnetfast_tiny": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_small": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_base": {
+        "input_shape": (3, 32, 32),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+}
+
+
+class TextNetFast(Sequential):
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    Args:
+        num_blocks: number of ResNet blocks in each stage
+        output_channels: number of channels in each stage
+        stage_conv: whether to add a conv_sequence after each stage
+        stage_pooling: pooling to add after each stage (if None, no pooling)
+        origin_stem: whether to use the original ResNet stem or ResNet-31's
+        stem_channels: number of output channels of the stem convolutions
+        attn_module: attention module to use in each stage
+        include_top: whether the classifier head should be instantiated
+        num_classes: number of output classes
+    """
+
+    def __init__(
+        self,
+        stage1: Dict,
+        stage2: Dict,
+        stage3: Dict,
+        stage4: Dict,
+        include_top: bool = True,
+        num_classes: int = 1000,
+        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        _layers = [
+            *conv_sequence(
+                input_shape=input_shape, out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2
+            )
+        ]
+
+        for stage in [stage1, stage2, stage3, stage4]:
+            stage_ = Sequential([RepConvLayer(**params) for params in stage])
+            _layers.extend([stage_])
+
+        if include_top:
+            _layers.extend(
+                [layers.GlobalAveragePooling2D(), layers.Flatten(), layers.Dense(num_classes, activation=None)]
+            )
+
+        super().__init__(_layers)
+        self.cfg = cfg
+
+
+def _textnetfast(
+    arch: str,
+    pretrained: bool,
+    arch_fn,
+    ignore_keys: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> TextNetFast:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["classes"] = kwargs["classes"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = arch_fn(**kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    model.cfg = _cfg
+
+    return model
+
+
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_tiny
+    >>> model = textnetfast_tiny(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_tiny",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage3=[
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+        ],
+        stage4=[
+            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 512, "kernel_size": [3, 3], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
+
+
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector
+    with Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_small
+    >>> model = textnetfast_small(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_small",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+        ],
+        stage2=[
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage4=[
+            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
+
+
+def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_base
+    >>> model = textnetfast_base(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_base",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage4=[
+            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+        ],
+        ignore_keys=["10.weight", "10.bias"],
+        **kwargs,
+    )
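One layout note on the shapes used in this file: Keras consumes channels-last (NHWC) tensors, which is why the TensorFlow side should use `(32, 32, 3)` where the PyTorch side uses `(3, 32, 32)`. A small sketch of the correspondence:

```python
import tensorflow as tf

nhwc = tf.random.uniform((1, 32, 32, 3))  # TensorFlow: batch, height, width, channels
nchw = tf.transpose(nhwc, [0, 3, 1, 2])   # PyTorch-style: batch, channels, height, width
print(nhwc.shape, nchw.shape)             # (1, 32, 32, 3) (1, 3, 32, 32)
```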
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 62772e97a5..37234655a5 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,12 +1,12 @@
 import torch.nn as nn
 
+
 class RepConvLayer(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1, groups=1):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
         super(RepConvLayer, self).__init__()
-        
-        padding = (int(((kernel_size[0] - 1) * dilation) / 2),
-                   int(((kernel_size[1] - 1) * dilation) / 2))
-        
+
+        padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2))
+
         self.activation = nn.ReLU(inplace=True)
         self.main_conv = nn.Conv2d(
             in_channels,
@@ -14,7 +14,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            dilation=dilation, 
+            dilation=dilation,
             groups=groups,
             bias=False,
         )
@@ -22,8 +22,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1,
 
         ver_pad = (int(((kernel_size[0] - 1) * dilation) / 2), 0)
         hor_pad = (0, int(((kernel_size[1] - 1) * dilation) / 2))
-        
-        
+
         if kernel_size[1] != 1:
             self.ver_conv = nn.Conv2d(
                 in_channels,
@@ -31,14 +30,14 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1,
                 kernel_size=(kernel_size[0], 1),
                 stride=stride,
                 padding=ver_pad,
-                dilation=dilation, 
+                dilation=dilation,
                 groups=groups,
-                bias=False
+                bias=False,
             )
             self.ver_bn = nn.BatchNorm2d(out_channels)
         else:
-             self.ver_conv, self.ver_bn = None, None
-             
+            self.ver_conv, self.ver_bn = None, None
+
         if kernel_size[0] != 1:
             self.hor_conv = nn.Conv2d(
                 in_channels,
@@ -46,18 +45,17 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1,
                 kernel_size=(1, kernel_size[1]),
                 stride=stride,
                 padding=hor_pad,
-                dilation=dilation, 
+                dilation=dilation,
                 groups=groups,
-                bias=False
+                bias=False,
             )
             self.hor_bn = nn.BatchNorm2d(out_channels)
         else:
             self.hor_conv, self.hor_bn = None, None
-        
+
         self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
 
     def forward(self, input):
-    
         main_outputs = self.main_bn(self.main_conv(input))
         vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
         horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
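`RepConvLayer` follows the structural-reparameterization idea used by FAST: at train time it sums a full kxk branch, a vertical kx1 branch, a horizontal 1xk branch and an identity BatchNorm, each padded so the outputs align. A minimal shape check against the constructor defined above (a sketch, not part of the patch):

```python
import torch

from doctr.models.modules.layers.pytorch import RepConvLayer

layer = RepConvLayer(in_channels=8, out_channels=8, kernel_size=[3, 3], stride=1)
x = torch.rand(2, 8, 16, 16)
# all four branches yield 2x8x16x16 maps, so the element-wise sum is well-defined
print(layer(x).shape)  # torch.Size([2, 8, 16, 16])
```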
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index eca30f75ec..6b824e19e9 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -1,17 +1,12 @@
-import math
-from typing import Any, Tuple
-
-import tensorflow as tf
 from tensorflow.keras import layers
 
 
 class RepConvLayer(layers.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1, groups=1):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
         super(RepConvLayer, self).__init__()
-        
-        padding = (int(((kernel_size[0] - 1) * dilation) / 2),
-                   int(((kernel_size[1] - 1) * dilation) / 2))
-        
+
+        padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2))
+
         self.activation = layers.ReLU()
         self.main_conv = layers.Conv2D(
             filters=out_channels,
@@ -21,49 +16,45 @@ def __init__(self, in_channels, out_channels, kernel_size, stride,  dilation=1,
             dilation_rate=dilation,
             groups=groups,
             use_bias=False,
-            input_shape=(None, None, in_channels) 
+            input_shape=(None, None, in_channels),
         )
 
         self.main_bn = layers.BatchNormalization()
 
-        ver_pad = (int(((kernel_size[0] - 1) * dilation) / 2), 0)
-        hor_pad = (0, int(((kernel_size[1] - 1) * dilation) / 2))
-        
-        
+
         if kernel_size[1] != 1:
             self.ver_conv = layers.Conv2D(
-            filters=out_channels,
-            kernel_size=(kernel_size[0], 1),
-            strides=(stride, 1),
-            padding='valid',  
-            dilation_rate=(dilation, 1),
-            groups=groups,
-            use_bias=False,
-            input_shape=(None, None, in_channels) 
+                filters=out_channels,
+                kernel_size=(kernel_size[0], 1),
+                strides=(stride, 1),
+                padding="valid",
+                dilation_rate=(dilation, 1),
+                groups=groups,
+                use_bias=False,
+                input_shape=(None, None, in_channels),
             )
             self.ver_bn = layers.BatchNormalization()
         else:
-             self.ver_conv, self.ver_bn = None, None
-             
+            self.ver_conv, self.ver_bn = None, None
+
         if kernel_size[0] != 1:
             self.hor_conv = layers.Conv2D(
                 filters=out_channels,
                 kernel_size=(1, kernel_size[1]),
                 strides=stride,
-                padding='valid',  # TensorFlow utilise 'valid' pour l'équivalent de 'same' de PyTorch
+                padding="valid",  # TensorFlow utilise 'valid' pour l'équivalent de 'same' de PyTorch
                 dilation_rate=dilation,
                 groups=groups,
                 use_bias=False,
-                input_shape=(None, None, in_channels)  # Spécifiez la forme de l'entrée ici
-                )
+                input_shape=(None, None, in_channels),  # specify the input shape here
+            )
             self.hor_bn = layers.BatchNormalization()
         else:
             self.hor_conv, self.hor_bn = None, None
-        
+
         self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
 
     def forward(self, input):
-    
         main_outputs = self.main_bn(self.main_conv(input))
         vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
         horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
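A caveat on `padding="valid"` above: Keras `Conv2D` only accepts the strings `"valid"` or `"same"`, so the asymmetric tuple padding computed on the PyTorch side has to be emulated explicitly, for instance with `ZeroPadding2D`. A sketch of that pattern (shapes chosen for illustration only):

```python
import tensorflow as tf
from tensorflow.keras import layers

# emulate PyTorch padding=(1, 0) for a (3, 1) kernel: pad the height only,
# then convolve with no implicit padding
pad_then_conv = tf.keras.Sequential([
    layers.ZeroPadding2D(padding=((1, 1), (0, 0))),
    layers.Conv2D(filters=64, kernel_size=(3, 1), strides=1, padding="valid", use_bias=False),
])

x = tf.random.uniform((1, 32, 32, 3))
print(pad_then_conv(x).shape)  # (1, 32, 32, 64): spatial size preserved
```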
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
index 346b320259..8490c09f11 100644
--- a/doctr/models/utils/tensorflow.py
+++ b/doctr/models/utils/tensorflow.py
@@ -66,7 +66,7 @@ def conv_sequence(
     bn: bool = False,
     padding: str = "same",
     kernel_initializer: str = "he_normal",
-    **kwargs: Any,    
+    **kwargs: Any,
 ) -> List[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py
index 5f2f7cc2ed..9184b9d251 100644
--- a/tests/pytorch/test_models_classification_pt.py
+++ b/tests/pytorch/test_models_classification_pt.py
@@ -44,9 +44,9 @@ def _test_classification(model, input_shape, output_size, batch_size=2):
         ["vit_b", (3, 32, 32), (126,)],
         # Check that the interpolation of positional embeddings for vit models works correctly
         ["vit_s", (3, 64, 64), (126,)],
-        ["textnetfast_tiny",(3 ,32, 32), (126,)],
-        ["textnetfast_small",(3 ,32, 32), (126,)],
-        ["textnetfast_base",(3 ,32, 32), (126,)],
+        ["textnetfast_tiny", (3, 32, 32), (126,)],
+        ["textnetfast_small", (3, 32, 32), (126,)],
+        ["textnetfast_base", (3, 32, 32), (126,)],
     ],
 )
 def test_classification_architectures(arch_name, input_shape, output_size):
@@ -128,9 +128,9 @@ def test_crop_orientation_model(mock_text_box):
         ["mobilenet_v3_large", (3, 32, 32), (126,)],
         ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
         ["vit_b", (3, 32, 32), (126,)],
-        ["textnetfast_tiny",(3 ,32, 32), (126,)],
-        ["textnetfast_small",(3 ,32, 32), (126,)],
-        ["textnetfast_base",(3 ,32, 32), (126,)],
+        ["textnetfast_tiny", (3, 32, 32), (126,)],
+        ["textnetfast_small", (3, 32, 32), (126,)],
+        ["textnetfast_base", (3, 32, 32), (126,)],
     ],
 )
 def test_models_onnx_export(arch_name, input_shape, output_size):
diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py
index 84eefb9f9d..e38386ba40 100644
--- a/tests/tensorflow/test_models_classification_tf.py
+++ b/tests/tensorflow/test_models_classification_tf.py
@@ -139,24 +139,21 @@ def test_crop_orientation_model(mock_text_box):
             (126,),
             marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
         ),
-        pytest.param(
-             "textnetfast_tiny", 
-             (32, 32, 3), 
-             (126,),
-             marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
-        ),
-        pytest.param(
-             "textnetfast_small", 
-             (32, 32, 3), 
-             (126,),
-             marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
-        ),
-        pytest.param(
-             "textnetfast_base", 
-             (32, 32, 3), 
-             (126,),
-             marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"),
-        ),
+        [
+            "textnetfast_tiny",
+            (32, 32, 3),
+            (126,),
+        ],
+        [
+            "textnetfast_small",
+            (32, 32, 3),
+            (126,),
+        ],
+        [
+            "textnetfast_base",
+            (32, 32, 3),
+            (126,),
+        ],
     ],
 )
 def test_models_onnx_export(arch_name, input_shape, output_size):
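For reference, `pytest.param` is what lets a single parametrized case carry marks such as `skipif`; the bare lists used for the TextNetFast cases cannot, so the available-memory guard no longer applies to them. A minimal sketch of the marked form:

```python
import pytest

system_available_memory = 8  # placeholder value; the real test module probes this from the host

case = pytest.param(
    "textnetfast_base",
    (32, 32, 3),
    (126,),
    marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory"),
)
```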

From fd3d85b461374bb5cc8591854415e13df80fb527 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 13:12:23 +0200
Subject: [PATCH 12/44] accept extra keyword arguments in TF TextNetFast constructor

---
 doctr/models/classification/textnet_fast/tensorflow.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 447fb0854e..2e27ab44ce 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -63,6 +63,7 @@ def __init__(
         num_classes: int = 1000,
         input_shape: Tuple[int, int, int] = (32, 32, 3),
         cfg: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> None:
         _layers = [
             *conv_sequence(
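Accepting `**kwargs` makes the TF constructor tolerant of extra factory arguments (for instance a hypothetical `input_shape` override forwarded by `_textnetfast`) instead of raising `TypeError`. A minimal illustration of the pattern:

```python
class Flexible:
    def __init__(self, num_classes: int = 1000, **kwargs) -> None:
        # unrecognized keyword arguments are absorbed rather than rejected
        self.num_classes = num_classes


Flexible(num_classes=10, some_future_option=True)  # no TypeError
```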

From 10e7e05d742ac431aa89b193350007513d565e34 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 13:14:53 +0200
Subject: [PATCH 13/44] type the TextNetFast stage arguments

---
 doctr/models/classification/textnet_fast/pytorch.py    | 8 ++++----
 doctr/models/classification/textnet_fast/tensorflow.py | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 2a262b5598..288b22b130 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -55,10 +55,10 @@ class TextNetFast(nn.Sequential):
 
     def __init__(
         self,
-        stage1: Dict,
-        stage2: Dict,
-        stage3: Dict,
-        stage4: Dict,
+        stage1: List[Dict[str, Union[int, List[int]]]],
+        stage2: List[Dict[str, Union[int, List[int]]]],
+        stage3: List[Dict[str, Union[int, List[int]]]],
+        stage4: List[Dict[str, Union[int, List[int]]]],
         include_top: bool = True,
         num_classes: int = 1000,
         cfg: Optional[Dict[str, Any]] = None,
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 2e27ab44ce..f00916aed5 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -55,10 +55,10 @@ class TextNetFast(Sequential):
 
     def __init__(
         self,
-        stage1: Dict,
-        stage2: Dict,
-        stage3: Dict,
-        stage4: Dict,
+        stage1: List[Dict[str, Union[int, List[int]]]],
+        stage2: List[Dict[str, Union[int, List[int]]]],
+        stage3: List[Dict[str, Union[int, List[int]]]],
+        stage4: List[Dict[str, Union[int, List[int]]]],
         include_top: bool = True,
         num_classes: int = 1000,
         input_shape: Tuple[int, int, int] = (32, 32, 3),

From d49c3f2f2948e7e07cf3caf781ff34cd8833bf12 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 14:51:38 +0200
Subject: [PATCH 14/44] align TextNetFast docstrings with the constructor arguments

---
 .../classification/textnet_fast/pytorch.py       | 16 +++++++---------
 .../classification/textnet_fast/tensorflow.py    | 16 +++++++---------
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 288b22b130..23204c1a73 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -42,15 +42,13 @@ class TextNetFast(nn.Sequential):
     Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
 
     Args:
-        num_blocks: number of ResNet blocks in each stage
-        output_channels: number of channels in each stage
-        stage_conv: whether to add a conv_sequence after each stage
-        stage_pooling: pooling to add after each stage (if None, no pooling)
-        origin_stem: whether to use the original ResNet stem or ResNet-31's
-        stem_channels: number of output channels of the stem convolutions
-        attn_module: attention module to use in each stage
-        include_top: whether the classifier head should be instantiated
-        num_classes: number of output classes
+        stage1 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 1
+        stage2 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 2
+        stage3 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 3
+        stage4 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 4
+        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+        num_classes (int, optional): Number of output classes. Defaults to 1000.
+        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index f00916aed5..ed40cac57d 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -42,15 +42,13 @@ class TextNetFast(Sequential):
     Implementation based on the `official PyTorch implementation <https://github.com/czczup/FAST>`_.
 
     Args:
-        num_blocks: number of ResNet blocks in each stage
-        output_channels: number of channels in each stage
-        stage_conv: whether to add a conv_sequence after each stage
-        stage_pooling: pooling to add after each stage (if None, no pooling)
-        origin_stem: whether to use the original ResNet stem or ResNet-31's
-        stem_channels: number of output channels of the stem convolutions
-        attn_module: attention module to use in each stage
-        include_top: whether the classifier head should be instantiated
-        num_classes: number of output classes
+        stage1 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 1
+        stage2 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 2
+        stage3 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 3
+        stage4 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 4
+        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+        num_classes (int, optional): Number of output classes. Defaults to 1000.
+        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
     """
 
     def __init__(

From 6599951a69a1950a8ebdcd4f9e55957eb264ba1f Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 14:57:08 +0200
Subject: [PATCH 15/44] [skip ci] refactor RepConvLayer (PyTorch): typing and kwargs handling

---
 doctr/models/modules/layers/pytorch.py | 53 +++++++++++++++-----------
 1 file changed, 31 insertions(+), 22 deletions(-)

diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 37234655a5..ede7813da3 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,9 +1,28 @@
+from typing import Any, List, Tuple, Union
+
+import torch
 import torch.nn as nn
 
+__all__ = ["RepConvLayer"]
+
 
 class RepConvLayer(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
-        super(RepConvLayer, self).__init__()
+    """Reparameterized Convolutional Layer"""
+
+    def __init__(
+        self, in_channels: int, out_channels: int, kernel_size: Union[List[int], Tuple[int, int], int], **kwargs: Any
+    ) -> None:
+        super().__init__()
+
+        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        dilation = kwargs.get("dilation", 1)  # kept in kwargs: it is forwarded to each nn.Conv2d below
+        stride = kwargs.pop("stride", 1)  # popped: passed explicitly below, must not repeat via **kwargs
+        kwargs.pop("padding", None)
+        kwargs.pop("bias", None)
+
+        self.hor_conv, self.hor_bn = None, None
+        self.ver_conv, self.ver_bn = None, None
 
         padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2))
 
@@ -14,29 +33,22 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
-            dilation=dilation,
-            groups=groups,
             bias=False,
+            **kwargs,
         )
         self.main_bn = nn.BatchNorm2d(out_channels)
 
-        ver_pad = (int(((kernel_size[0] - 1) * dilation) / 2), 0)
-        hor_pad = (0, int(((kernel_size[1] - 1) * dilation) / 2))
-
         if kernel_size[1] != 1:
             self.ver_conv = nn.Conv2d(
                 in_channels,
                 out_channels,
                 kernel_size=(kernel_size[0], 1),
                 stride=stride,
-                padding=ver_pad,
-                dilation=dilation,
-                groups=groups,
+                padding=(int(((kernel_size[0] - 1) * dilation) / 2), 0),
                 bias=False,
+                **kwargs,
             )
             self.ver_bn = nn.BatchNorm2d(out_channels)
-        else:
-            self.ver_conv, self.ver_bn = None, None
 
         if kernel_size[0] != 1:
             self.hor_conv = nn.Conv2d(
@@ -44,21 +56,18 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                 out_channels,
                 kernel_size=(1, kernel_size[1]),
                 stride=stride,
-                padding=hor_pad,
-                dilation=dilation,
-                groups=groups,
+                padding=(0, int(((kernel_size[1] - 1) * dilation) / 2)),
                 bias=False,
+                **kwargs,
             )
             self.hor_bn = nn.BatchNorm2d(out_channels)
-        else:
-            self.hor_conv, self.hor_bn = None, None
 
         self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
 
-    def forward(self, input):
-        main_outputs = self.main_bn(self.main_conv(input))
-        vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
-        horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
-        id_out = self.rbr_identity(input) if self.rbr_identity is not None else 0
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        main_outputs = self.main_bn(self.main_conv(x))
+        vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None else 0
+        horizontal_outputs = self.hor_bn(self.hor_conv(x)) if self.hor_conv is not None else 0
+        id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
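The block above follows the asymmetric-convolution idea used by ACNet/RepVGG-style reparameterized networks: at train time, parallel branches over the same input are summed before the activation. A hedged, self-contained sketch of that summation (an illustrative stand-in, not the doctr layer itself):

    import torch
    import torch.nn as nn

    class TinyRepConv(nn.Module):
        """Sum of a 3x3, a 3x1, a 1x3 and an identity branch, then ReLU."""

        def __init__(self, c: int):
            super().__init__()
            self.main = nn.Conv2d(c, c, (3, 3), padding=(1, 1), bias=False)
            self.ver = nn.Conv2d(c, c, (3, 1), padding=(1, 0), bias=False)
            self.hor = nn.Conv2d(c, c, (1, 3), padding=(0, 1), bias=False)
            self.act = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.act(self.main(x) + self.ver(x) + self.hor(x) + x)

    out = TinyRepConv(8)(torch.rand(1, 8, 16, 16))  # shape preserved: (1, 8, 16, 16)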

From 491ddd5c3fa819716b098f3615ddfde56425b580 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 14:58:19 +0200
Subject: [PATCH 16/44] [skip ci] some changes 5

---
 doctr/models/modules/layers/tensorflow.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index 6b824e19e9..e2c5bc4b06 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -56,10 +56,14 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
 
         self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
 
-    def forward(self, input):
-        main_outputs = self.main_bn(self.main_conv(input))
-        vertical_outputs = self.ver_bn(self.ver_conv(input)) if self.ver_conv is not None else 0
-        horizontal_outputs = self.hor_bn(self.hor_conv(input)) if self.hor_conv is not None else 0
-        id_out = self.rbr_identity(input) if self.rbr_identity is not None else 0
+    def call(
+        self,
+        x: tf.Tensor,
+        **kwargs: Any,
+    ) -> tf.Tensor:
+        main_outputs = self.main_bn(self.main_conv(x, **kwargs), **kwargs)
+        vertical_outputs = self.ver_bn(self.ver_conv(x, **kwargs), **kwargs) if self.ver_conv is not None else 0
+        horizontal_outputs = self.hor_bn(self.hor_conv(x, **kwargs), **kwargs) if self.hor_conv is not None else 0
+        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
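The rename from forward to call matters in Keras because the framework only dispatches through call, and forwarding **kwargs lets the training flag reach BatchNormalization, which uses batch statistics in training and moving averages at inference. A minimal sketch (an illustrative stand-in, not the doctr layer):

    import tensorflow as tf
    from tensorflow.keras import layers

    class ConvBN(layers.Layer):
        def __init__(self):
            super().__init__()
            self.conv = layers.Conv2D(8, 3, padding="same", use_bias=False)
            self.bn = layers.BatchNormalization()

        def call(self, x, **kwargs):
            # kwargs forwards the training flag down to batch norm
            return tf.nn.relu(self.bn(self.conv(x), **kwargs))

    layer = ConvBN()
    x = tf.random.uniform((1, 32, 32, 3))
    y_train = layer(x, training=True)   # batch statistics
    y_infer = layer(x, training=False)  # moving averages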

From 2f2e769f5c6496c86e993923960f1d0e0df387a9 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 30 Aug 2023 16:12:23 +0200
Subject: [PATCH 17/44] [skip ci] some changes 5

---
 .../classification/textnet_fast/pytorch.py    |   2 +-
 .../classification/textnet_fast/tensorflow.py | 144 +++++++++---------
 doctr/models/modules/layers/pytorch.py        |   4 +-
 doctr/models/modules/layers/tensorflow.py     |  76 ++++-----
 4 files changed, 116 insertions(+), 110 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 23204c1a73..002beb6b27 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -5,7 +5,7 @@
 
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch.nn as nn
 
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index ed40cac57d..246c9ea79d 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -4,7 +4,7 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from tensorflow.keras import layers
 from tensorflow.keras.models import Sequential
@@ -134,27 +134,27 @@ def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
         pretrained,
         TextNetFast,
         stage1=[
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
         ],
         stage2=[
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
         ],
         stage3=[
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
         ],
         stage4=[
-            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 512, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
         ],
         ignore_keys=["10.weight", "10.bias"],
         **kwargs,
@@ -184,35 +184,35 @@ def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
         pretrained,
         TextNetFast,
         stage1=[
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
         ],
         stage2=[
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
         ],
         stage3=[
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
         ],
         stage4=[
-            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
         ],
         ignore_keys=["10.weight", "10.bias"],
         **kwargs,
@@ -242,45 +242,45 @@ def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
         pretrained,
         TextNetFast,
         stage1=[
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
         ],
         stage2=[
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
         ],
         stage3=[
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
         ],
         stage4=[
-            {"out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
         ],
         ignore_keys=["10.weight", "10.bias"],
         **kwargs,
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index ede7813da3..e3901a40ff 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -31,11 +31,11 @@ def __init__(
             in_channels,
             out_channels,
             kernel_size=kernel_size,
-            stride=stride,
             padding=padding,
             bias=False,
             **kwargs,
         )
+
         self.main_bn = nn.BatchNorm2d(out_channels)
 
         if kernel_size[1] != 1:
@@ -43,7 +43,6 @@ def __init__(
                 in_channels,
                 out_channels,
                 kernel_size=(kernel_size[0], 1),
-                stride=stride,
                 padding=(int(((kernel_size[0] - 1) * dilation) / 2), 0),
                 bias=False,
                 **kwargs,
@@ -55,7 +54,6 @@ def __init__(
                 in_channels,
                 out_channels,
                 kernel_size=(1, kernel_size[1]),
-                stride=stride,
                 padding=(0, int(((kernel_size[1] - 1) * dilation) / 2)),
                 bias=False,
                 **kwargs,
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index e2c5bc4b06..891c5cad4e 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -1,55 +1,59 @@
 from tensorflow.keras import layers
+import tensorflow as tf
 
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 class RepConvLayer(layers.Layer):
     def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
         super(RepConvLayer, self).__init__()
 
+        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
         padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2))
 
         self.activation = layers.ReLU()
-        self.main_conv = layers.Conv2D(
-            filters=out_channels,
-            kernel_size=kernel_size,
-            strides=stride,
-            padding=padding,
-            dilation_rate=dilation,
-            groups=groups,
-            use_bias=False,
-            input_shape=(None, None, in_channels),
-        )
-
-        self.main_bn = layers.BatchNormalization()
-
-        (int(((kernel_size[0] - 1) * dilation) / 2), 0)
-        (0, int(((kernel_size[1] - 1) * dilation) / 2))
-
-        if kernel_size[1] != 1:
-            self.ver_conv = layers.Conv2D(
+        self.main_conv = tf.keras.Sequential([
+            layers.ZeroPadding2D(padding=padding),
+            layers.Conv2D(
                 filters=out_channels,
-                kernel_size=(kernel_size[0], 1),
-                strides=(stride, 1),
-                padding="valid",
-                dilation_rate=(dilation, 1),
+                kernel_size=kernel_size,
+                strides=stride,
+                dilation_rate=dilation,
                 groups=groups,
                 use_bias=False,
                 input_shape=(None, None, in_channels),
-            )
+        )])
+
+        self.main_bn = layers.BatchNormalization()
+
+        if kernel_size[1] != 1:
+            self.ver_conv = tf.keras.Sequential([
+                layers.ZeroPadding2D(padding=padding),
+                layers.Conv2D(
+                    filters=out_channels,
+                    kernel_size=(kernel_size[0], 1),
+                    strides=(stride, 1),
+                    dilation_rate=(dilation, 1),
+                    groups=groups,
+                    use_bias=False,
+                    input_shape=(None, None, in_channels))])
+
             self.ver_bn = layers.BatchNormalization()
         else:
             self.ver_conv, self.ver_bn = None, None
 
         if kernel_size[0] != 1:
-            self.hor_conv = layers.Conv2D(
-                filters=out_channels,
-                kernel_size=(1, kernel_size[1]),
-                strides=stride,
-                padding="valid",  # TensorFlow utilise 'valid' pour l'équivalent de 'same' de PyTorch
-                dilation_rate=dilation,
-                groups=groups,
-                use_bias=False,
-                input_shape=(None, None, in_channels),  # Specify the input shape here
-            )
+            self.hor_conv = tf.keras.Sequential([
+                layers.ZeroPadding2D(padding=padding),
+                layers.Conv2D(
+                    filters=out_channels,
+                    kernel_size=(1, kernel_size[1]),
+                    strides=stride,
+                    dilation_rate=dilation,
+                    groups=groups,
+                    use_bias=False,
+                    input_shape=(None, None, in_channels))])
+
             self.hor_bn = layers.BatchNormalization()
         else:
             self.hor_conv, self.hor_bn = None, None
@@ -66,4 +70,8 @@ def call(
         horizontal_outputs = self.hor_bn(self.hor_conv(x, **kwargs), **kwargs) if self.hor_conv is not None else 0
         id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
 
-        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+        p = main_outputs + vertical_outputs
+        q = horizontal_outputs + id_out
+        r = p + q
+
+        return self.activation(r)
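A hedged sketch of how these stage dicts are consumed on the PyTorch side (assuming this branch of doctr is importable): each dict is splatted into RepConvLayer, which is why in_channels now has to chain correctly from one layer to the next within a stage:

    import torch
    import torch.nn as nn

    from doctr.models.modules.layers.pytorch import RepConvLayer

    stage2 = [
        {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
        {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
    ]
    block = nn.Sequential(*[RepConvLayer(**params) for params in stage2])
    out = block(torch.rand(1, 64, 32, 32))  # -> (1, 128, 16, 16)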

From f315122129afabf6a086c52a48e4456be3fa47b0 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 31 Aug 2023 10:16:38 +0200
Subject: [PATCH 18/44] [skip ci] removing ignore_keys in tf textnet model

---
 .../classification/textnet_fast/tensorflow.py |  2 -
 doctr/models/modules/layers/tensorflow.py     | 76 +++++++++++--------
 2 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 246c9ea79d..71baa842b8 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -86,7 +86,6 @@ def _textnetfast(
     arch: str,
     pretrained: bool,
     arch_fn,
-    ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
 ) -> TextNetFast:
     kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
@@ -103,7 +102,6 @@ def _textnetfast(
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
-        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
         load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     model.cfg = _cfg
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index 891c5cad4e..364414462c 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -1,7 +1,8 @@
-from tensorflow.keras import layers
+from typing import Any
+
 import tensorflow as tf
+from tensorflow.keras import layers
 
-from typing import Any, Dict, List, Optional, Tuple, Union
 
 class RepConvLayer(layers.Layer):
     def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
@@ -12,47 +13,58 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
         padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2))
 
         self.activation = layers.ReLU()
-        self.main_conv = tf.keras.Sequential([
-            layers.ZeroPadding2D(padding=padding),
-            layers.Conv2D(
-                filters=out_channels,
-                kernel_size=kernel_size,
-                strides=stride,
-                dilation_rate=dilation,
-                groups=groups,
-                use_bias=False,
-                input_shape=(None, None, in_channels),
-        )])
-
-        self.main_bn = layers.BatchNormalization()
-
-        if kernel_size[1] != 1:
-            self.ver_conv = tf.keras.Sequential([
+        self.main_conv = tf.keras.Sequential(
+            [
                 layers.ZeroPadding2D(padding=padding),
                 layers.Conv2D(
                     filters=out_channels,
-                    kernel_size=(kernel_size[0], 1),
-                    strides=(stride, 1),
-                    dilation_rate=(dilation, 1),
+                    kernel_size=kernel_size,
+                    strides=stride,
+                    dilation_rate=dilation,
                     groups=groups,
                     use_bias=False,
-                    input_shape=(None, None, in_channels))])
+                    input_shape=(None, None, in_channels),
+                ),
+            ]
+        )
+
+        self.main_bn = layers.BatchNormalization()
+
+        if kernel_size[1] != 1:
+            self.ver_conv = tf.keras.Sequential(
+                [
+                    layers.ZeroPadding2D(padding=padding),
+                    layers.Conv2D(
+                        filters=out_channels,
+                        kernel_size=(kernel_size[0], 1),
+                        strides=(stride, 1),
+                        dilation_rate=(dilation, 1),
+                        groups=groups,
+                        use_bias=False,
+                        input_shape=(None, None, in_channels),
+                    ),
+                ]
+            )
 
             self.ver_bn = layers.BatchNormalization()
         else:
             self.ver_conv, self.ver_bn = None, None
 
         if kernel_size[0] != 1:
-            self.hor_conv = tf.keras.Sequential([
-                layers.ZeroPadding2D(padding=padding),
-                layers.Conv2D(
-                    filters=out_channels,
-                    kernel_size=(1, kernel_size[1]),
-                    strides=stride,
-                    dilation_rate=dilation,
-                    groups=groups,
-                    use_bias=False,
-                    input_shape=(None, None, in_channels))])
+            self.hor_conv = tf.keras.Sequential(
+                [
+                    layers.ZeroPadding2D(padding=padding),
+                    layers.Conv2D(
+                        filters=out_channels,
+                        kernel_size=(1, kernel_size[1]),
+                        strides=stride,
+                        dilation_rate=dilation,
+                        groups=groups,
+                        use_bias=False,
+                        input_shape=(None, None, in_channels),
+                    ),
+                ]
+            )
 
             self.hor_bn = layers.BatchNormalization()
         else:
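The ZeroPadding2D wrapper used above exists because Keras Conv2D only accepts 'valid' or 'same' padding, not an integer amount. Explicit zero padding followed by a 'valid' convolution reproduces PyTorch's Conv2d(padding=p), as this sketch (assuming dilation 1) shows:

    import tensorflow as tf
    from tensorflow.keras import layers

    p = (1, 1)  # padding for a 3x3 kernel with dilation 1
    branch = tf.keras.Sequential([
        layers.ZeroPadding2D(padding=p),
        layers.Conv2D(filters=16, kernel_size=(3, 3), strides=1, use_bias=False),
    ])
    x = tf.random.uniform((1, 32, 32, 3))
    print(branch(x).shape)  # (1, 32, 32, 16): spatial size preserved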

From ad1bc738d071742d8f70bdb6af32a8e17d130ab5 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 31 Aug 2023 10:17:33 +0200
Subject: [PATCH 19/44] [skip ci] removing ignore_keys in tf textnet model

---
 doctr/models/classification/textnet_fast/tensorflow.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 71baa842b8..c5a5cb393c 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -98,8 +98,7 @@ def _textnetfast(
 
     # Build the model
     model = arch_fn(**kwargs)
-    # Load pretrained parameters
-    if pretrained:
+    # Load pretrained parameters 
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
@@ -154,7 +153,6 @@ def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
         **kwargs,
     )
 
@@ -212,7 +210,6 @@ def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
         **kwargs,
     )
 
@@ -280,6 +277,5 @@ def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
         **kwargs,
     )

From 0bf73b3fe96926ab299992151bb151d290e26ab4 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 31 Aug 2023 11:12:05 +0200
Subject: [PATCH 20/44] [skip ci] merge first layers of the model into a single block

---
 .../models/classification/textnet_fast/pytorch.py  | 14 +++++++-------
 .../classification/textnet_fast/tensorflow.py      |  1 +
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 002beb6b27..1527a2b224 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -62,10 +62,10 @@ def __init__(
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
         _layers: List[nn.Module]
+        super().__init__()
         self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
-
-        _layers = [*self.first_conv]
-
+        self.first_conv = nn.Sequential(*self.first_conv)
+        _layers = [self.first_conv]
         for stage in [stage1, stage2, stage3, stage4]:
             stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
             _layers.extend([*stage_])
@@ -88,7 +88,7 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-
+        print(self)
 
 def _textnetfast(
     arch: str,
@@ -164,7 +164,7 @@ def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
+        ignore_keys=["18.weight", "18.bias"],
         **kwargs,
     )
 
@@ -222,7 +222,7 @@ def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
+        ignore_keys=["26.weight", "26.bias"],
         **kwargs,
     )
 
@@ -290,6 +290,6 @@ def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
         ],
-        ignore_keys=["10.weight", "10.bias"],
+        ignore_keys=["36.weight", "36.bias"],
         **kwargs,
     )
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index c5a5cb393c..65f5148494 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -99,6 +99,7 @@ def _textnetfast(
     # Build the model
     model = arch_fn(**kwargs)
     # Load pretrained parameters 
+    if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
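For context, ignore_keys exists so that a pretrained checkpoint can still be loaded when num_classes differs: the classifier-head entries are dropped from the state_dict and the rest is loaded non-strictly. A hedged sketch of that behaviour (a hypothetical helper; the real load_pretrained_params may differ in detail):

    import torch.nn as nn

    def load_ignoring(model: nn.Module, state_dict: dict, ignore_keys: list) -> None:
        # drop the head weights whose shape no longer matches
        for key in ignore_keys:
            state_dict.pop(key, None)
        model.load_state_dict(state_dict, strict=False)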

From b131218ad827d03f57b4c734742808036e09e888 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 31 Aug 2023 11:21:12 +0200
Subject: [PATCH 21/44] [skip ci] creating blocks in layers

---
 doctr/models/classification/textnet_fast/pytorch.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 1527a2b224..c96870b28c 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -68,16 +68,17 @@ def __init__(
         _layers = [self.first_conv]
         for stage in [stage1, stage2, stage3, stage4]:
             stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
-            _layers.extend([*stage_])
+            stage_ = nn.Sequential(*stage_)
+            _layers.extend([stage_])
 
         if include_top:
-            _layers.extend(
-                [
+            classif_block = [
                     nn.AdaptiveAvgPool2d(1),
                     nn.Flatten(1),
                     nn.Linear(512, num_classes, bias=True),
                 ]
-            )
+            classif_block = nn.Sequential(*nn.ModuleList(classif_block))
+            _layers.extend([classif_block])
 
         super().__init__(*_layers)
         self.cfg = cfg
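The reason each stage gets wrapped in nn.Sequential above: nn.ModuleList only registers submodules so they appear in parameters() and state_dict(), but it defines no forward() and cannot be called directly. A minimal illustration:

    import torch
    import torch.nn as nn

    stage = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
    y = stage(torch.rand(1, 3, 16, 16))  # callable
    # nn.ModuleList([nn.Conv2d(3, 8, 3)])(x) would raise NotImplementedError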

From f9a63c0b67c08e2ccceee3b67cb4b070318ed26c Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 31 Aug 2023 11:51:11 +0200
Subject: [PATCH 22/44] [skip ci] correction of some errors in make quality

---
 .../classification/textnet_fast/pytorch.py    | 17 ++++++-------
 .../classification/textnet_fast/tensorflow.py |  4 ++--
 doctr/models/modules/layers/pytorch.py        | 24 +++++++++++++------
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index c96870b28c..c927b25025 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -66,6 +66,7 @@ def __init__(
         self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
         self.first_conv = nn.Sequential(*self.first_conv)
         _layers = [self.first_conv]
+
         for stage in [stage1, stage2, stage3, stage4]:
             stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
             stage_ = nn.Sequential(*stage_)
@@ -73,10 +74,10 @@ def __init__(
 
         if include_top:
             classif_block = [
-                    nn.AdaptiveAvgPool2d(1),
-                    nn.Flatten(1),
-                    nn.Linear(512, num_classes, bias=True),
-                ]
+                nn.AdaptiveAvgPool2d(1),
+                nn.Flatten(1),
+                nn.Linear(512, num_classes, bias=True),
+            ]
             classif_block = nn.Sequential(*nn.ModuleList(classif_block))
             _layers.extend([classif_block])
 
@@ -89,7 +90,7 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-        print(self)
+
 
 def _textnetfast(
     arch: str,
@@ -165,7 +166,7 @@ def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
         ],
-        ignore_keys=["18.weight", "18.bias"],
+        ignore_keys=["4.3.weight", "4.3.bias"],
         **kwargs,
     )
 
@@ -223,7 +224,7 @@ def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
         ],
-        ignore_keys=["26.weight", "26.bias"],
+        ignore_keys=["4.3.weight", "4.3.bias"],
         **kwargs,
     )
 
@@ -291,6 +292,6 @@ def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
             {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
             {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
         ],
-        ignore_keys=["36.weight", "36.bias"],
+        ignore_keys=["4.3.weight", "4.3.bias"],
         **kwargs,
     )
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 65f5148494..c0e3527208 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -98,11 +98,11 @@ def _textnetfast(
 
     # Build the model
     model = arch_fn(**kwargs)
-    # Load pretrained parameters 
+    # Load pretrained parameters
     if pretrained:
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        load_pretrained_params(model, default_cfgs[arch]["url"])
 
     model.cfg = _cfg
 
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index e3901a40ff..2bac0fbfe7 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Tuple, Union
+from typing import Any, Union
 
 import torch
 import torch.nn as nn
@@ -9,9 +9,7 @@
 class RepConvLayer(nn.Module):
     """Reparameterized Convolutional Layer"""
 
-    def __init__(
-        self, in_channels: int, out_channels: int, kernel_size: Union[List[int], Tuple[int, int], int], **kwargs: Any
-    ) -> None:
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any], **kwargs: Any) -> None:
         super().__init__()
 
         kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
@@ -64,8 +62,20 @@ def __init__(
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         main_outputs = self.main_bn(self.main_conv(x))
-        vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None else 0
-        horizontal_outputs = self.hor_bn(self.hor_conv(x)) if self.hor_conv is not None else 0
-        id_out = self.rbr_identity(x) if self.rbr_identity is not None else 0
+
+        if self.ver_conv is not None and self.ver_bn is not None:
+            vertical_outputs = self.ver_bn(self.ver_conv(x))
+        else:
+            vertical_outputs = 0
+
+        if self.hor_bn is not None and self.hor_conv is not None:
+            horizontal_outputs = self.hor_bn(self.hor_conv(x))
+        else:
+            horizontal_outputs = 0
+
+        if self.rbr_identity is not None:
+            id_out = self.rbr_identity(x)
+        else:
+            id_out = 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
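Where keys like "4.3.weight" come from: nn.Sequential names its children by integer index, and nesting produces dotted state_dict paths, so the ignore_keys entries must track the Linear head's position in the module tree. A minimal sketch:

    import torch.nn as nn

    head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1), nn.Linear(512, 10))
    model = nn.Sequential(nn.Identity(), nn.Identity(), head)
    print(list(model.state_dict()))  # ['2.2.weight', '2.2.bias']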

From 064ca6bd8f98442776efb03229449da86a2b7e6c Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Fri, 1 Sep 2023 20:24:39 +0200
Subject: [PATCH 23/44] [skip ci] adding fast inference for textnet model

---
 Untitled.ipynb                                |  6 +++
 .../classification/textnet_fast/pytorch.py    | 33 +++++++-----
 .../classification/textnet_fast/tensorflow.py |  8 +--
 doctr/models/utils/pytorch.py                 | 43 +++++++++++++++
 inference.py                                  | 54 +++++++++++++++++++
 5 files changed, 126 insertions(+), 18 deletions(-)
 create mode 100644 Untitled.ipynb
 create mode 100644 inference.py

diff --git a/Untitled.ipynb b/Untitled.ipynb
new file mode 100644
index 0000000000..363fcab7ed
--- /dev/null
+++ b/Untitled.ipynb
@@ -0,0 +1,6 @@
+{
+ "cells": [],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index c927b25025..3629a7213d 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-
+from doctr.models.utils.pytorch import fuse_conv_bn, fuse_module, rep_model_convert
 from ...utils import load_pretrained_params
 
 __all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
@@ -53,24 +53,23 @@ class TextNetFast(nn.Sequential):
 
     def __init__(
         self,
-        stage1: Dict[str, Union[int, List[int]]],
-        stage2: Dict[str, Union[int, List[int]]],
-        stage3: Dict[str, Union[int, List[int]]],
-        stage4: Dict[str, Union[int, List[int]]],
+        stage1: List[Dict[str, Union[int, List[int]]]],
+        stage2: List[Dict[str, Union[int, List[int]]]],
+        stage3: List[Dict[str, Union[int, List[int]]]],
+        stage4: List[Dict[str, Union[int, List[int]]]],
         include_top: bool = True,
         num_classes: int = 1000,
         cfg: Optional[Dict[str, Any]] = None,
     ) -> None:
-        _layers: List[nn.Module]
+        _layers: List[Any]
         super().__init__()
-        self.first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
-        self.first_conv = nn.Sequential(*self.first_conv)
+        first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
+        self.first_conv = nn.Sequential(*first_conv)
         _layers = [self.first_conv]
 
         for stage in [stage1, stage2, stage3, stage4]:
-            stage_ = nn.ModuleList([RepConvLayer(**params) for params in stage])
-            stage_ = nn.Sequential(*stage_)
-            _layers.extend([stage_])
+            self.stage_ = nn.Sequential(*[RepConvLayer(**params) for params in stage])  # type: ignore[arg-type]
+            _layers.extend([self.stage_])
 
         if include_top:
             classif_block = [
@@ -78,8 +77,8 @@ def __init__(
                 nn.Flatten(1),
                 nn.Linear(512, num_classes, bias=True),
             ]
-            classif_block = nn.Sequential(*nn.ModuleList(classif_block))
-            _layers.extend([classif_block])
+            classif_block_ = nn.Sequential(*nn.ModuleList(classif_block))
+            _layers.extend([classif_block_])
 
         super().__init__(*_layers)
         self.cfg = cfg
@@ -105,8 +104,10 @@ def _textnetfast(
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["num_classes"] = kwargs["num_classes"]
     _cfg["classes"] = kwargs["classes"]
+    training = kwargs["training"]
     kwargs.pop("classes")
-
+    kwargs.pop("training")
+    
     # Build the model
     model = arch_fn(**kwargs)
     # Load pretrained parameters
@@ -118,6 +119,10 @@ def _textnetfast(
 
     model.cfg = _cfg
 
+    if training is False:
+        model = rep_model_convert(model)
+        model = fuse_module(model)
+    
     return model
 
 
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index c0e3527208..72628938de 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -53,10 +53,10 @@ class TextNetFast(Sequential):
 
     def __init__(
         self,
-        stage1: Dict[str, Union[int, List[int]]],
-        stage2: Dict[str, Union[int, List[int]]],
-        stage3: Dict[str, Union[int, List[int]]],
-        stage4: Dict[str, Union[int, List[int]]],
+        stage1: List[Dict[str, Union[int, List[int]]]],
+        stage2: List[Dict[str, Union[int, List[int]]]],
+        stage3: List[Dict[str, Union[int, List[int]]]],
+        stage4: List[Dict[str, Union[int, List[int]]]],
         include_top: bool = True,
         num_classes: int = 1000,
         input_shape: Tuple[int, int, int] = (32, 32, 3),
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index b24030ca17..8093ac1bf0 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -150,3 +150,46 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     )
     logging.info(f"Model exported to {model_name}.onnx")
     return f"{model_name}.onnx"
+
+
+def rep_model_convert(model:torch.nn.Module):
+    for module in model.modules():
+        if hasattr(module, 'switch_to_deploy'):
+            module.switch_to_deploy()
+            # print("switch_to_deploy")
+    return model
+    
+ 
+def fuse_conv_bn(conv, bn):
+    """During inference, the functionary of batch norm layers is turned off but
+    only the mean and var alone channels are used, which exposes the chance to
+    fuse it with the preceding conv layers to save computations and simplify
+    network structures."""
+    conv_w = conv.weight
+    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
+
+    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+    conv.weight = nn.Parameter(conv_w * factor.reshape([conv.out_channels, 1, 1, 1]))
+    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+    return conv
+
+
+def fuse_module(m):
+    last_conv = None
+    last_conv_name = None
+
+    for name, child in m.named_children():
+        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            fused_conv = fuse_conv_bn(last_conv, child)
+            m._modules[last_conv_name] = fused_conv
+            # To reduce changes, set BN as Identity instead of deleting it.
+            m._modules[name] = nn.Identity()
+            last_conv = None
+        elif isinstance(child, nn.Conv2d):
+            last_conv = child
+            last_conv_name = name
+        else:
+            fuse_module(child)
+    return m
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000..e23a836482
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,54 @@
+# git clone https://github.com/mindee/doctr.git
+# pip install -e doctr/.[tf]
+# conda install -y -c conda-forge weasyprint
+
+import json
+import os
+
+os.environ["USE_TF"] = "1"
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
+
+import tensorflow as tf
+
+from doctr.io import DocumentFile
+from doctr.models import ocr_predictor
+
+gpu_devices = tf.config.experimental.list_physical_devices("GPU")
+if any(gpu_devices):
+    tf.config.experimental.set_memory_growth(gpu_devices[0], True)
+
+
+def main(args):
+    # Load docTR model
+    model = ocr_predictor(det_arch=args.arch_detection, reco_arch=args.arch_recognition, pretrained=True)
+
+    # load image input file
+    single_img_doc = DocumentFile.from_images(args.input_file)
+
+    # inference
+    output = model(single_img_doc)
+
+    with open(args.output_file, "w") as f:
+        json.dump(output.export(), f)
+
+
+def parse_args():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="docTR inference image script(TensorFlow)",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument("--arch_recognition", type=str, help="text-detection model")
+    parser.add_argument("--arch_detection", type=str, help="text-recognition model")
+    parser.add_argument("--input_file", type=str, help="path of image file")
+    parser.add_argument("--output_file", type=str, help="path of output file")
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
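A self-contained check of the fuse_conv_bn algebra added above: in eval mode, batch norm applies a per-channel affine y = (x - mean) / sqrt(var + eps) * gamma + beta, which folds into the preceding conv as W' = W * factor and b' = (b - mean) * factor + beta, with factor = gamma / sqrt(var + eps):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(16).eval()
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)

    x = torch.rand(2, 8, 32, 32)
    with torch.no_grad():
        ref = bn(conv(x))

        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
        fused.weight = nn.Parameter(conv.weight * factor.reshape(16, 1, 1, 1))
        fused.bias = nn.Parameter((torch.zeros(16) - bn.running_mean) * factor + bn.bias)
        print(torch.allclose(ref, fused(x), atol=1e-5))  # True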

From 53c144b239b4216dfb3b1fa93b82ee9994a9ebaf Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sun, 3 Sep 2023 13:36:14 +0200
Subject: [PATCH 24/44] [skip ci] some changes 6

---
 doctr/models/classification/textnet_fast/pytorch.py | 11 +++++------
 doctr/models/utils/pytorch.py                       |  9 ++++-----
 2 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 3629a7213d..cddd676cb5 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,8 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.utils.pytorch import fuse_conv_bn, fuse_module, rep_model_convert
+from doctr.models.utils.pytorch import fuse_module, rep_model_convert
+
 from ...utils import load_pretrained_params
 
 __all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
@@ -104,10 +105,8 @@ def _textnetfast(
     _cfg = deepcopy(default_cfgs[arch])
     _cfg["num_classes"] = kwargs["num_classes"]
     _cfg["classes"] = kwargs["classes"]
-    training = kwargs["training"]
     kwargs.pop("classes")
-    kwargs.pop("training")
-    
+
     # Build the model
     model = arch_fn(**kwargs)
     # Load pretrained parameters
@@ -119,10 +118,10 @@ def _textnetfast(
 
     model.cfg = _cfg
 
-    if training is False:
+    if model.training is False:
         model = rep_model_convert(model)
         model = fuse_module(model)
-    
+
     return model
 
 
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index 8093ac1bf0..05747c2b82 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -152,14 +152,13 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
     return f"{model_name}.onnx"
 
 
-def rep_model_convert(model:torch.nn.Module):
+def rep_model_convert(model: torch.nn.Module):
     for module in model.modules():
-        if hasattr(module, 'switch_to_deploy'):
+        if hasattr(module, "switch_to_deploy"):
             module.switch_to_deploy()
-            # print("switch_to_deploy")
     return model
-    
- 
+
+
 def fuse_conv_bn(conv, bn):
     """During inference, the functionary of batch norm layers is turned off but
     only the mean and var alone channels are used, which exposes the chance to
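The switch_to_deploy hook that rep_model_convert looks for follows the RepVGG-style convention of collapsing parallel branches into a single conv before inference. A hedged, self-contained sketch of the idea (names illustrative, not the doctr implementation):

    import torch
    import torch.nn as nn

    class ToyRepBlock(nn.Module):
        def __init__(self, c: int):
            super().__init__()
            self.conv3 = nn.Conv2d(c, c, 3, padding=1)
            self.conv1 = nn.Conv2d(c, c, 1)

        def forward(self, x):
            if hasattr(self, "conv1"):
                return self.conv3(x) + self.conv1(x)
            return self.conv3(x)

        @torch.no_grad()
        def switch_to_deploy(self):
            # pad the 1x1 kernel to 3x3 (centre tap) and merge both branches
            self.conv3.weight += nn.functional.pad(self.conv1.weight, [1, 1, 1, 1])
            self.conv3.bias += self.conv1.bias
            del self.conv1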

From 774138dfb495a7c715388783078e838bd3e34005 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 4 Sep 2023 16:55:43 +0200
Subject: [PATCH 25/44] [skip ci] override eval method to speed up eval mode
 of TextNetFast

---
 .../classification/textnet_fast/pytorch.py    | 14 ++++++++-----
 doctr/models/utils/pytorch.py                 | 20 +++++++++++++++++++
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index cddd676cb5..f2f3d5c887 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -90,7 +90,15 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-
+    def eval(self):
+        model = rep_model_convert(model)
+        model = fuse_module(model)
+        model.eval()
+        
+    def train():
+        model = rep_model_unconvert(model)
+        model = unfuse_module(model)
+        model.train()
 
 def _textnetfast(
     arch: str,
@@ -118,10 +126,6 @@ def _textnetfast(
 
     model.cfg = _cfg
 
-    if model.training is False:
-        model = rep_model_convert(model)
-        model = fuse_module(model)
-
     return model
 
 
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index 05747c2b82..1890020493 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -192,3 +192,23 @@ def fuse_module(m):
         else:
             fuse_module(child)
     return m
+
+def fuse_module(m):
+    last_conv = None
+    last_conv_name = None
+
+    for name, child in m.named_children():
+        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            fused_conv = fuse_conv_bn(last_conv, child)
+            m._modules[last_conv_name] = fused_conv
+            # To reduce changes, set BN as Identity instead of deleting it.
+            m._modules[name] = nn.Identity()
+            last_conv = None
+        elif isinstance(child, nn.Conv2d):
+            last_conv = child
+            last_conv_name = name
+        else:
+            fuse_module(child)
+    return m
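For reference on the override above: torch.nn.Module.eval() is defined as return self.train(False), so custom eval/train overrides conventionally funnel through train(mode) and return self. A hedged sketch, with the (un)fusing left as hypothetical hooks:

    import torch.nn as nn

    class DeployAware(nn.Module):
        def train(self, mode: bool = True):
            super().train(mode)
            # hypothetical hooks: unfuse for training, fuse for inference
            return self

        def eval(self):
            return self.train(False)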

From a7e8d73b44286abd219e0671b2e7e07ad1d08d1c Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 4 Sep 2023 17:29:40 +0200
Subject: [PATCH 26/44] [skip ci] override eval and train modes for TextNetFast
 done

---
 .../classification/textnet_fast/pytorch.py    | 21 +++++++-----
 doctr/models/utils/pytorch.py                 | 32 +++++++++++++++----
 2 files changed, 39 insertions(+), 14 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index f2f3d5c887..fe87a2c87e 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.utils.pytorch import fuse_module, rep_model_convert
+from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_conv_bn, unfuse_module
 
 from ...utils import load_pretrained_params
 
@@ -90,15 +90,20 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
+                
     def eval(self):
-        model = rep_model_convert(model)
-        model = fuse_module(model)
-        model.eval()
+        model = rep_model_convert(self)
+        model = fuse_module(self)
+        for param in self.parameters():
+            param.requires_grad = False
+        self.training = False
         
-    def train():
-        model = rep_model_unconvert(model)
-        model = unfuse_module(model)
-        model.train()
+        
+    def train(self):
+        model = unfuse_module(self)
+        for param in self.parameters():
+            param.requires_grad = True
+        self.training = True
 
 def _textnetfast(
     arch: str,
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index 1890020493..14c45e9423 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -168,8 +168,11 @@ def fuse_conv_bn(conv, bn):
     conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
 
     factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+    conv.old_weight = conv.weight
+    conv.old_bias = conv.bias
     conv.weight = nn.Parameter(conv_w * factor.reshape([conv.out_channels, 1, 1, 1]))
     conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+    
     return conv
 
 
@@ -193,22 +196,39 @@ def fuse_module(m):
             fuse_module(child)
     return m
 
-def fuse_module(m):
+
+def unfuse_conv_bn(conv, bn):
+    """During inference, the functionary of batch norm layers is turned off but
+    only the mean and var alone channels are used, which exposes the chance to
+    fuse it with the preceding conv layers to save computations and simplify
+    network structures."""
+    conv.weight = conv.old_weight
+    conv.bias = conv.old_biais
+    
+    return conv
+
+def unfuse_module(m):
     last_conv = None
     last_conv_name = None
 
     for name, child in m.named_children():
-        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
+        if isinstance(child, nn.Identity):
-            if last_conv is None:  # only fuse BN that is after Conv
+            if last_conv is None:  # only unfuse an Identity that follows a conv
                 continue
-            fused_conv = fuse_conv_bn(last_conv, child)
-            m._modules[last_conv_name] = fused_conv
+            unfused_conv = unfuse_conv_bn(last_conv, child)
+            m._modules[last_conv_name] = unfused_conv
-            # To reduce changes, set BN as Identity instead of deleting it.
+            # Restore a fresh BatchNorm (running stats re-initialized) in place of the Identity.
-            m._modules[name] = nn.Identity()
+            m._modules[name] = nn.BatchNorm2d(unfused_conv.out_channels)
             last_conv = None
         elif isinstance(child, nn.Conv2d):
             last_conv = child
             last_conv_name = name
         else:
-            fuse_module(child)
+            unfuse_module(child)
     return m
+    
+def rep_model_convert(model: torch.nn.Module):
+    for module in model.modules():
+        if hasattr(module, "switch_to_deploy"):
+            module.switch_to_deploy()
+    return model

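Since unfuse_conv_bn restores the parameters stashed by fuse_conv_bn, the two module walkers should round-trip. A sketch of that round trip, assuming the helpers behave as defined in this patch (note the restored BatchNorm is freshly initialized, so its running statistics are lost):

    import torch
    import torch.nn as nn
    from doctr.models.utils.pytorch import fuse_module, unfuse_module

    # fuse_module folds each BN into the preceding conv and leaves an Identity;
    # unfuse_module puts the saved conv parameters back and re-creates a BN.
    m = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)).eval()
    w_before = m[0].weight.detach().clone()

    fuse_module(m)
    assert isinstance(m[1], nn.Identity)

    unfuse_module(m)
    assert isinstance(m[1], nn.BatchNorm2d)
    assert torch.allclose(m[0].weight, w_before)
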
From ae5f7c427e8449a682f0a7243ff858ca927e6118 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 4 Sep 2023 17:43:24 +0200
Subject: [PATCH 27/44] [skip ci] apply make style and make quality

---
 .../models/classification/textnet_fast/pytorch.py  | 14 +++++++-------
 doctr/models/utils/pytorch.py                      | 11 +++--------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index fe87a2c87e..c5b1d0b4c8 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_conv_bn, unfuse_module
+from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_module
 
 from ...utils import load_pretrained_params
 
@@ -90,21 +90,21 @@ def __init__(
             elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
-                
+
     def eval(self):
-        model = rep_model_convert(self)
-        model = fuse_module(self)
+        rep_model_convert(self)
+        fuse_module(self)
         for param in self.parameters():
             param.requires_grad = False
         self.training = False
-        
-        
+
     def train(self):
-        model = unfuse_module(self)
+        unfuse_module(self)
         for param in self.parameters():
             param.requires_grad = True
         self.training = True
 
+
 def _textnetfast(
     arch: str,
     pretrained: bool,
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index 14c45e9423..f9a603644d 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -172,7 +172,7 @@ def fuse_conv_bn(conv, bn):
     conv.old_bias = conv.bias
     conv.weight = nn.Parameter(conv_w * factor.reshape([conv.out_channels, 1, 1, 1]))
     conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
-    
+
     return conv
 
 
@@ -204,9 +204,10 @@ def unfuse_conv_bn(conv, bn):
     network structures."""
     conv.weight = conv.old_weight
     conv.bias = conv.old_bias
-    
+
     return conv
 
+
 def unfuse_module(m):
     last_conv = None
     last_conv_name = None
@@ -226,9 +227,3 @@ def unfuse_module(m):
         else:
             unfuse_module(child)
     return m
-    
-def rep_model_convert(model: torch.nn.Module):
-    for module in model.modules():
-        if hasattr(module, "switch_to_deploy"):
-            module.switch_to_deploy()
-    return model

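One caveat with the overrides above: a stock nn.Module.eval() only flips the training flag recursively and never touches requires_grad, while this override additionally reparameterizes, fuses, and freezes parameters. A two-assert reminder of the stock semantics, as a sketch:

    import torch.nn as nn

    m = nn.BatchNorm2d(8)
    m.eval()
    assert m.training is False  # stock eval(): training flag only
    assert all(p.requires_grad for p in m.parameters())  # gradients left enabled
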
From 7c0bba36b2f7fef21ef1e510886aca4f187af1f7 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 4 Sep 2023 23:00:36 +0200
Subject: [PATCH 28/44] [skip ci] change eval and train methods to switch
 RepConvLayer

---
 .../classification/textnet_fast/pytorch.py    |  10 +-
 .../classification/textnet_fast/tensorflow.py | 282 ------------------
 doctr/models/modules/layers/pytorch.py        | 131 +++++++-
 doctr/models/utils/pytorch.py                 |  10 +-
 4 files changed, 143 insertions(+), 290 deletions(-)
 delete mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index c5b1d0b4c8..5e96b39c04 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,7 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_module
+from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_module, rep_model_unconvert
 
 from ...utils import load_pretrained_params
 
@@ -92,14 +92,16 @@ def __init__(
                 nn.init.constant_(m.bias, 0)
 
     def eval(self):
-        rep_model_convert(self)
-        fuse_module(self)
+        self = rep_model_convert(self)
+        self = fuse_module(self)
         for param in self.parameters():
             param.requires_grad = False
         self.training = False
 
     def train(self):
-        unfuse_module(self)
+
+        self = unfuse_module(self)
+        self = rep_model_unconvert(self)
         for param in self.parameters():
             param.requires_grad = True
         self.training = True
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
deleted file mode 100644
index 72628938de..0000000000
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ /dev/null
@@ -1,282 +0,0 @@
-# Copyright (C) 2021-2023, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from tensorflow.keras import layers
-from tensorflow.keras.models import Sequential
-
-from doctr.datasets import VOCABS
-from doctr.models.modules.layers.tensorflow import RepConvLayer
-from doctr.models.utils.tensorflow import conv_sequence
-
-from ...utils import load_pretrained_params
-
-__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
-
-default_cfgs: Dict[str, Dict[str, Any]] = {
-    "textnetfast_tiny": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_small": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_base": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-}
-
-
-class TextNetFast(Sequential):
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    Args:
-        stage1 (Dict[str, Union[int, List[int]]]): Configuration for stage 1
-        stage2 (Dict[str, Union[int, List[int]]]): Configuration for stage 2
-        stage3 (Dict[str, Union[int, List[int]]]): Configuration for stage 3
-        stage4 (Dict[str, Union[int, List[int]]]): Configuration for stage 4
-        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
-        num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
-    """
-
-    def __init__(
-        self,
-        stage1: List[Dict[str, Union[int, List[int]]]],
-        stage2: List[Dict[str, Union[int, List[int]]]],
-        stage3: List[Dict[str, Union[int, List[int]]]],
-        stage4: List[Dict[str, Union[int, List[int]]]],
-        include_top: bool = True,
-        num_classes: int = 1000,
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
-        cfg: Optional[Dict[str, Any]] = None,
-        **kwargs: Any,
-    ) -> None:
-        _layers = [
-            *conv_sequence(
-                input_shape=input_shape, out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2
-            )
-        ]
-
-        for stage in [stage1, stage2, stage3, stage4]:
-            stage_ = Sequential([RepConvLayer(**params) for params in stage])
-            _layers.extend([stage_])
-
-        if include_top:
-            _layers.extend(
-                [layers.GlobalAveragePooling2D(), layers.Flatten(), layers.Dense(num_classes, activation=None)]
-            )
-
-        super().__init__(_layers)
-        self.cfg = cfg
-
-
-def _textnetfast(
-    arch: str,
-    pretrained: bool,
-    arch_fn,
-    **kwargs: Any,
-) -> TextNetFast:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["classes"] = kwargs["classes"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = arch_fn(**kwargs)
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # remove the last layer weights
-        load_pretrained_params(model, default_cfgs[arch]["url"])
-
-    model.cfg = _cfg
-
-    return model
-
-
-def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_tiny
-    >>> model = textnetfast_tiny(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNet model
-    """
-
-    return _textnetfast(
-        "textnetfast_tiny",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
-        ],
-        **kwargs,
-    )
-
-
-def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector
-    with Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_small
-    >>> model = textnetfast_small(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_small",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-        ],
-        **kwargs,
-    )
-
-
-def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_base
-    >>> model = textnetfast_base(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_base",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-        ],
-        **kwargs,
-    )
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 2bac0fbfe7..bd6991a2c9 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -2,6 +2,7 @@
 
 import torch
 import torch.nn as nn
+import numpy as np
 
 __all__ = ["RepConvLayer"]
 
@@ -9,11 +10,12 @@
 class RepConvLayer(nn.Module):
     """Reparameterized Convolutional Layer"""
 
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any], **kwargs: Any) -> None:
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any],groups: int =1, deploy: bool = False,  **kwargs: Any) -> None:
         super().__init__()
 
         kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
-
+        
+        
         dilation = kwargs.get("dilation", 1)
         stride = kwargs.get("stride", 1)
         kwargs.pop("padding", None)
@@ -60,6 +62,14 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any],
 
         self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
 
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.groups = groups
+        self.deploy = deploy
+        
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         main_outputs = self.main_bn(self.main_conv(x))
 
@@ -79,3 +89,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             id_out = 0
 
         return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+    def _identity_to_conv(self, identity):
+        if identity is None:
+            return 0, 0
+        assert isinstance(identity, nn.BatchNorm2d)
+        if not hasattr(self, 'id_tensor'):
+            input_dim = self.in_channels // self.groups
+            kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+            for i in range(self.in_channels):
+                kernel_value[i, i % input_dim, 0, 0] = 1
+            id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
+            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+        kernel = self.id_tensor
+        running_mean = identity.running_mean
+        running_var = identity.running_var
+        gamma = identity.weight
+        beta = identity.bias
+        eps = identity.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def _fuse_bn_tensor(self, conv, bn):
+        kernel = conv.weight
+        kernel = self._pad_to_mxn_tensor(kernel)
+        running_mean = bn.running_mean
+        running_var = bn.running_var
+        gamma = bn.weight
+        beta = bn.bias
+        eps = bn.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def get_equivalent_kernel_bias(self):
+        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.main_conv, self.main_bn)
+        if self.ver_conv is not None:
+            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
+        else:
+            kernel_mx1, bias_mx1 = 0, 0
+        if self.hor_conv is not None:
+            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
+        else:
+            kernel_1xn, bias_1xn = 0, 0
+        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+        return kernel_mxn, bias_mxn
+
+    def _pad_to_mxn_tensor(self, kernel):
+        kernel_height, kernel_width = self.kernel_size
+        height, width = kernel.shape[2:]
+        pad_left_right = (kernel_width - width) // 2
+        pad_top_down = (kernel_height - height) // 2
+        return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right,
+                                                pad_top_down, pad_top_down])
+
+    def switch_to_deploy(self):
+        if hasattr(self, 'fused_conv'):
+            return
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.fused_conv = nn.Conv2d(in_channels=self.main_conv.in_channels,
+                                    out_channels=self.main_conv.out_channels,
+                                    kernel_size=self.main_conv.kernel_size, stride=self.main_conv.stride,
+                                    padding=self.main_conv.padding, dilation=self.main_conv.dilation,
+                                    groups=self.main_conv.groups, bias=True)
+        self.fused_conv.weight.data = kernel
+        self.fused_conv.bias.data = bias
+        self.deploy = True
+        for para in self.parameters():
+            para.detach_()
+        for attr in ['main_conv', 'main_bn', 'ver_conv', 'ver_bn', 'hor_conv', 'hor_bn']:
+            if hasattr(self, attr):
+                self.__delattr__(attr)
+
+        if hasattr(self, 'rbr_identity'):
+            self.__delattr__('rbr_identity')
+
+    def switch_to_test(self):
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.fused_conv = nn.Conv2d(in_channels=self.main_conv.in_channels,
+                                    out_channels=self.main_conv.out_channels,
+                                    kernel_size=self.main_conv.kernel_size, stride=self.main_conv.stride,
+                                    padding=self.main_conv.padding, dilation=self.main_conv.dilation,
+                                    groups=self.main_conv.groups, bias=True)
+        self.fused_conv.weight.data = kernel
+        self.fused_conv.bias.data = bias
+        for para in self.fused_conv.parameters():
+            para.detach_()
+        self.deploy = True
+
+    def switch_to_train(self):
+        if hasattr(self, 'fused_conv'):
+            self.__delattr__('fused_conv')
+        self.deploy = False
+
+    @staticmethod
+    def is_zero_layer():
+        return False
+
+    @property
+    def module_str(self):
+        return 'Rep_%dx%d' % (self.kernel_size[0], self.kernel_size[1])
+
+    @property
+    def config(self):
+        return {'name': RepConvLayer.__name__,
+                'in_channels': self.in_channels,
+                'out_channels': self.out_channels,
+                'kernel_size': self.kernel_size,
+                'stride': self.stride,
+                'dilation': self.dilation,
+                'groups': self.groups}
+
+    @staticmethod
+    def build_from_config(config):
+        return RepConvLayer(**config)
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index f9a603644d..0d057971c5 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -154,8 +154,14 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
 
 def rep_model_convert(model: torch.nn.Module):
     for module in model.modules():
-        if hasattr(module, "switch_to_deploy"):
-            module.switch_to_deploy()
+        if hasattr(module, "switch_to_test"):
+            module.switch_to_test()  # type ignore[operator]
+    return model
+
+def rep_model_unconvert(model: torch.nn.Module):
+    for module in model.modules():
+        if hasattr(module, "switch_to_train"):
+            module.switch_to_train()  # type ignore[operator]
     return model
 
 

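The reparameterization implemented by get_equivalent_kernel_bias rests on the linearity of convolution: zero-padding each asymmetric kernel up to the main k x k shape (what _pad_to_mxn_tensor does) and summing the kernels is equivalent to summing the branch outputs. A standalone sketch of that identity:

    import torch
    import torch.nn.functional as F

    # Pad a 3x1 vertical and a 1x3 horizontal kernel up to 3x3, then one
    # convolution with the summed kernel matches the sum of the branches.
    main = torch.randn(4, 2, 3, 3)
    ver = torch.randn(4, 2, 3, 1)
    hor = torch.randn(4, 2, 1, 3)

    ver_padded = F.pad(ver, [1, 1, 0, 0])  # one zero column left and right
    hor_padded = F.pad(hor, [0, 0, 1, 1])  # one zero row top and bottom
    merged = main + ver_padded + hor_padded

    x = torch.randn(1, 2, 8, 8)
    branches = (
        F.conv2d(x, main, padding=1)
        + F.conv2d(x, ver, padding=(1, 0))
        + F.conv2d(x, hor, padding=(0, 1))
    )
    assert torch.allclose(branches, F.conv2d(x, merged, padding=1), atol=1e-5)
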
From 68147189db52caf1ac7d489a71a6964860a2d161 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 4 Sep 2023 23:20:13 +0200
Subject: [PATCH 29/44] [skip ci] change eval and train methods to switch
 RepConvLayer; set test method to deploy the model

---
 doctr/models/artefacts/face.py                |  4 +-
 .../classification/textnet_fast/__init__.py   |  2 +-
 .../classification/textnet_fast/pytorch.py    | 16 +++-
 .../classification/textnet_fast/tensorflow.py |  0
 doctr/models/modules/layers/pytorch.py        | 90 +++++++++++--------
 doctr/models/utils/pytorch.py                 | 12 ++-
 6 files changed, 80 insertions(+), 44 deletions(-)
 create mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/artefacts/face.py b/doctr/models/artefacts/face.py
index f79200a07e..7f8d2d4d74 100644
--- a/doctr/models/artefacts/face.py
+++ b/doctr/models/artefacts/face.py
@@ -28,9 +28,7 @@ def __init__(
     ) -> None:
         self.n_faces = n_faces
         # Instantiate classifier
-        self.detector = cv2.CascadeClassifier(
-            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"  # type: ignore[attr-defined]
-        )
+        self.detector = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
 
     def extra_repr(self) -> str:
         return f"n_faces={self.n_faces}"
diff --git a/doctr/models/classification/textnet_fast/__init__.py b/doctr/models/classification/textnet_fast/__init__.py
index c7110f5669..64556e403a 100644
--- a/doctr/models/classification/textnet_fast/__init__.py
+++ b/doctr/models/classification/textnet_fast/__init__.py
@@ -3,4 +3,4 @@
 if is_tf_available():
     from .tensorflow import *
 elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+    from .pytorch import *
diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 5e96b39c04..4e3256ec14 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -12,7 +12,13 @@
 from doctr.datasets import VOCABS
 from doctr.models.modules.layers.pytorch import RepConvLayer
 from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
-from doctr.models.utils.pytorch import fuse_module, rep_model_convert, unfuse_module, rep_model_unconvert
+from doctr.models.utils.pytorch import (
+    fuse_module,
+    rep_model_convert,
+    rep_model_convert_deploy,
+    rep_model_unconvert,
+    unfuse_module,
+)
 
 from ...utils import load_pretrained_params
 
@@ -99,13 +105,19 @@ def eval(self):
         self.training = False
 
     def train(self):
-
         self = unfuse_module(self)
         self = rep_model_unconvert(self)
         for param in self.parameters():
             param.requires_grad = True
         self.training = True
 
+    def test(self):
+        self = rep_model_convert_deploy(self)
+        self = fuse_module(self)
+        for param in self.parameters():
+            param.requires_grad = False
+        self.training = False
+
 
 def _textnetfast(
     arch: str,
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index bd6991a2c9..c54be1bd07 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,8 +1,8 @@
 from typing import Any, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 
 __all__ = ["RepConvLayer"]
 
@@ -10,12 +10,19 @@
 class RepConvLayer(nn.Module):
     """Reparameterized Convolutional Layer"""
 
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any],groups: int =1, deploy: bool = False,  **kwargs: Any) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[Any],
+        groups: int = 1,
+        deploy: bool = False,
+        **kwargs: Any,
+    ) -> None:
         super().__init__()
 
         kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
-        
-        
+
         dilation = kwargs.get("dilation", 1)
         stride = kwargs.get("stride", 1)
         kwargs.pop("padding", None)
@@ -69,7 +76,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[Any],
         self.dilation = dilation
         self.groups = groups
         self.deploy = deploy
-        
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         main_outputs = self.main_bn(self.main_conv(x))
 
@@ -94,7 +101,7 @@ def _identity_to_conv(self, identity):
         if identity is None:
             return 0, 0
         assert isinstance(identity, nn.BatchNorm2d)
-        if not hasattr(self, 'id_tensor'):
+        if not hasattr(self, "id_tensor"):
             input_dim = self.in_channels // self.groups
             kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
             for i in range(self.in_channels):
@@ -107,7 +114,7 @@ def _identity_to_conv(self, identity):
         gamma = identity.weight
         beta = identity.bias
         eps = identity.eps
-        std = (running_var + eps).sqrt()
+        std = (running_var + eps).sqrt()  # type: ignore
         t = (gamma / std).reshape(-1, 1, 1, 1)
         return kernel * t, beta - running_mean * gamma / std
 
@@ -143,46 +150,55 @@ def _pad_to_mxn_tensor(self, kernel):
         height, width = kernel.shape[2:]
         pad_left_right = (kernel_width - width) // 2
         pad_top_down = (kernel_height - height) // 2
-        return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right,
-                                                pad_top_down, pad_top_down])
+        return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down])
 
     def switch_to_deploy(self):
-        if hasattr(self, 'fused_conv'):
+        if hasattr(self, "fused_conv"):
             return
         kernel, bias = self.get_equivalent_kernel_bias()
-        self.fused_conv = nn.Conv2d(in_channels=self.main_conv.in_channels,
-                                    out_channels=self.main_conv.out_channels,
-                                    kernel_size=self.main_conv.kernel_size, stride=self.main_conv.stride,
-                                    padding=self.main_conv.padding, dilation=self.main_conv.dilation,
-                                    groups=self.main_conv.groups, bias=True)
+        self.fused_conv = nn.Conv2d(
+            in_channels=self.main_conv.in_channels,
+            out_channels=self.main_conv.out_channels,
+            kernel_size=self.main_conv.kernel_size,  # type: ignore
+            stride=self.main_conv.stride,  # type: ignore
+            padding=self.main_conv.padding,  # type: ignore
+            dilation=self.main_conv.dilation,  # type: ignore
+            groups=self.main_conv.groups,
+            bias=True,
+        )
         self.fused_conv.weight.data = kernel
-        self.fused_conv.bias.data = bias
+        self.fused_conv.bias.data = bias  # type: ignore
         self.deploy = True
         for para in self.parameters():
             para.detach_()
-        for attr in ['main_conv', 'main_bn', 'ver_conv', 'ver_bn', 'hor_conv', 'hor_bn']:
+        for attr in ["main_conv", "main_bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
             if hasattr(self, attr):
                 self.__delattr__(attr)
 
-        if hasattr(self, 'rbr_identity'):
-            self.__delattr__('rbr_identity')
+        if hasattr(self, "rbr_identity"):
+            self.__delattr__("rbr_identity")
 
     def switch_to_test(self):
         kernel, bias = self.get_equivalent_kernel_bias()
-        self.fused_conv = nn.Conv2d(in_channels=self.main_conv.in_channels,
-                                    out_channels=self.main_conv.out_channels,
-                                    kernel_size=self.main_conv.kernel_size, stride=self.main_conv.stride,
-                                    padding=self.main_conv.padding, dilation=self.main_conv.dilation,
-                                    groups=self.main_conv.groups, bias=True)
-        self.fused_conv.weight.data = kernel
-        self.fused_conv.bias.data = bias
+        self.fused_conv = nn.Conv2d(
+            in_channels=self.main_conv.in_channels,
+            out_channels=self.main_conv.out_channels,
+            kernel_size=self.main_conv.kernel_size,  # type: ignore
+            stride=self.main_conv.stride,  # type: ignore
+            padding=self.main_conv.padding,  # type: ignore
+            dilation=self.main_conv.dilation,  # type: ignore
+            groups=self.main_conv.groups,
+            bias=True,
+        )
+        self.fused_conv.weight.data = kernel  # type: ignore
+        self.fused_conv.bias.data = bias  # type: ignore
         for para in self.fused_conv.parameters():
             para.detach_()
         self.deploy = True
 
     def switch_to_train(self):
-        if hasattr(self, 'fused_conv'):
-            self.__delattr__('fused_conv')
+        if hasattr(self, "fused_conv"):
+            self.__delattr__("fused_conv")
         self.deploy = False
 
     @staticmethod
@@ -191,17 +207,19 @@ def is_zero_layer():
 
     @property
     def module_str(self):
-        return 'Rep_%dx%d' % (self.kernel_size[0], self.kernel_size[1])
+        return "Rep_%dx%d" % (self.kernel_size[0], self.kernel_size[1])
 
     @property
     def config(self):
-        return {'name': RepConvLayer.__name__,
-                'in_channels': self.in_channels,
-                'out_channels': self.out_channels,
-                'kernel_size': self.kernel_size,
-                'stride': self.stride,
-                'dilation': self.dilation,
-                'groups': self.groups}
+        return {
+            "name": RepConvLayer.__name__,
+            "in_channels": self.in_channels,
+            "out_channels": self.out_channels,
+            "kernel_size": self.kernel_size,
+            "stride": self.stride,
+            "dilation": self.dilation,
+            "groups": self.groups,
+        }
 
     @staticmethod
     def build_from_config(config):
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index 0d057971c5..e438498ad4 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -155,13 +155,21 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T
 def rep_model_convert(model: torch.nn.Module):
     for module in model.modules():
         if hasattr(module, "switch_to_test"):
-            module.switch_to_test()  # type ignore[operator]
+            module.switch_to_test()  # type: ignore
     return model
 
+
 def rep_model_unconvert(model: torch.nn.Module):
     for module in model.modules():
         if hasattr(module, "switch_to_train"):
-            module.switch_to_train()  # type ignore[operator]
+            module.switch_to_train()  # type: ignore
+    return model
+
+
+def rep_model_convert_deploy(model: torch.nn.Module):
+    for module in model.modules():
+        if hasattr(module, "switch_to_deploy"):
+            module.switch_to_deploy()  # type: ignore
     return model
 
 

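With rep_model_convert now routed through switch_to_test, the fused kernel returned by get_equivalent_kernel_bias can be checked directly against the multi-branch forward. A sketch, assuming the patched RepConvLayer is importable and that its activation is the ReLU used by the FAST reference code:

    import torch
    import torch.nn.functional as F
    from doctr.models.modules.layers.pytorch import RepConvLayer

    layer = RepConvLayer(in_channels=8, out_channels=8, kernel_size=[3, 3], stride=1)
    layer.eval()  # BN branches must use running statistics for the algebra to hold
    kernel, bias = layer.get_equivalent_kernel_bias()

    x = torch.rand(1, 8, 16, 16)
    with torch.no_grad():
        expected = layer(x)  # activation(main + vertical + horizontal + identity)
        fused = F.relu(F.conv2d(x, kernel, bias, stride=1, padding=1))  # ReLU assumed
    assert torch.allclose(expected, fused, atol=1e-5)
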
From eac952ea03edcbca2e56a262801a24500615bd78 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 5 Sep 2023 12:03:31 +0200
Subject: [PATCH 30/44] [skip ci] TextNetFast PyTorch model done

---
 .../classification/textnet_fast/pytorch.py    | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py
index 4e3256ec14..bce89525d8 100644
--- a/doctr/models/classification/textnet_fast/pytorch.py
+++ b/doctr/models/classification/textnet_fast/pytorch.py
@@ -97,26 +97,29 @@ def __init__(
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
-    def eval(self):
+    def eval(self, mode=False):
         self = rep_model_convert(self)
         self = fuse_module(self)
         for param in self.parameters():
-            param.requires_grad = False
-        self.training = False
+            param.requires_grad = mode
+        self.training = mode
+        return self
 
-    def train(self):
+    def train(self, mode=True):
         self = unfuse_module(self)
         self = rep_model_unconvert(self)
         for param in self.parameters():
-            param.requires_grad = True
-        self.training = True
+            param.requires_grad = mode
+        self.training = mode
+        return self
 
-    def test(self):
+    def test(self, mode=False):
         self = rep_model_convert_deploy(self)
         self = fuse_module(self)
         for param in self.parameters():
-            param.requires_grad = False
-        self.training = False
+            param.requires_grad = mode
+        self.training = mode
+        return self
 
 
 def _textnetfast(

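With the three methods now returning self, the intended protocol reads naturally in user code. A usage sketch, assuming textnetfast_tiny is exported at the doctr.models root like the other classifiers:

    from doctr.models import textnetfast_tiny

    model = textnetfast_tiny(pretrained=False)
    model = model.eval()   # switch_to_test + fuse_module: reversible fusion
    model = model.train()  # unfuse_module + switch_to_train: training form restored
    model = model.test()   # switch_to_deploy + fuse_module: permanent deployment form
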
From 0ef122f1a8f0475c078944416157918b2e750004 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Tue, 5 Sep 2023 17:52:42 +0200
Subject: [PATCH 31/44] TextNetFast model implemented in PyTorch

---
 doctr/models/classification/textnet_fast/tensorflow.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
deleted file mode 100644
index e69de29bb2..0000000000

From dcd2ececeb4db69c786ae29864579d05fa1695ac Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Fri, 8 Sep 2023 18:18:16 +0200
Subject: [PATCH 32/44] add TextNetFast TensorFlow implementation

---
 .../classification/textnet_fast/__init__.py   |   2 +-
 .../classification/textnet_fast/tensorflow.py | 309 ++++++++++++++++++
 doctr/models/modules/layers/tensorflow.py     |   8 +-
 doctr/models/utils/tensorflow.py              | 111 ++++++-
 4 files changed, 424 insertions(+), 6 deletions(-)
 create mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/classification/textnet_fast/__init__.py b/doctr/models/classification/textnet_fast/__init__.py
index 64556e403a..c7110f5669 100644
--- a/doctr/models/classification/textnet_fast/__init__.py
+++ b/doctr/models/classification/textnet_fast/__init__.py
@@ -3,4 +3,4 @@
 if is_tf_available():
     from .tensorflow import *
 elif is_torch_available():
-    from .pytorch import *
+    from .pytorch import *  # type: ignore[assignment]
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
new file mode 100644
index 0000000000..fb0acecf9c
--- /dev/null
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -0,0 +1,309 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import tensorflow as tf
+
+from doctr.datasets import VOCABS
+from doctr.models.modules.layers.tensorflow import RepConvLayer
+from doctr.models.utils.tensorflow import (
+    conv_sequence,
+    fuse_module,
+    rep_model_convert,
+    rep_model_convert_deploy,
+    rep_model_unconvert,
+    unfuse_module,
+)
+
+from ...utils import load_pretrained_params
+
+__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "textnetfast_tiny": {
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_small": {
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+    "textnetfast_base": {
+        "input_shape": (32, 32, 3),
+        "classes": list(VOCABS["french"]),
+        "url": None,
+    },
+}
+
+
+class TextNetFast(tf.keras.Sequential):
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    Args:
+        stage1 (Dict[str, Union[int, List[int]]]): Configuration for stage 1
+        stage2 (Dict[str, Union[int, List[int]]]): Configuration for stage 2
+        stage3 (Dict[str, Union[int, List[int]]]): Configuration for stage 3
+        stage4 (Dict[str, Union[int, List[int]]]): Configuration for stage 4
+        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
+        num_classes (int, optional): Number of output classes. Defaults to 1000.
+        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        stage1: List[Dict[str, Union[int, List[int]]]],
+        stage2: List[Dict[str, Union[int, List[int]]]],
+        stage3: List[Dict[str, Union[int, List[int]]]],
+        stage4: List[Dict[str, Union[int, List[int]]]],
+        include_top: bool = True,
+        num_classes: int = 1000,
+        input_shape: Tuple[int, int, int] = (32, 32, 3),
+        cfg: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        first_conv = tf.keras.Sequential(
+            conv_sequence(
+                out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
+            )
+        )
+        _layers = [first_conv]
+
+        for stage in [stage1, stage2, stage3, stage4]:
+            stage_ = tf.keras.Sequential([RepConvLayer(**params) for params in stage])
+            _layers.extend([stage_])
+
+        if include_top:
+            classif_block = [
+                tf.keras.layers.GlobalAveragePooling2D(),
+                tf.keras.layers.Flatten(),
+                tf.keras.layers.Dense(num_classes, activation="softmax", use_bias=True, kernel_initializer="he_normal"),
+            ]
+            classif_block_ = tf.keras.Sequential(classif_block)
+            _layers.append(classif_block_)
+
+        super().__init__(_layers)
+        self.cfg = cfg
+
+    def eval(self, mode=False):
+        self = rep_model_convert(self)
+        self = fuse_module(self)
+        self.trainable = mode
+        return self
+
+    def train(self, mode=True):
+        self = unfuse_module(self)
+        self = rep_model_unconvert(self)
+        self.trainable = mode
+        return self
+
+    def test(self, mode=False):
+        self = rep_model_convert_deploy(self)
+        self = fuse_module(self)
+        self.trainable = mode
+        return self
+
+
+def _textnetfast(
+    arch: str,
+    pretrained: bool,
+    arch_fn,
+    **kwargs: Any,
+) -> TextNetFast:
+    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
+    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
+    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
+
+    _cfg = deepcopy(default_cfgs[arch])
+    _cfg["num_classes"] = kwargs["num_classes"]
+    _cfg["classes"] = kwargs["classes"]
+    _cfg["input_shape"] = kwargs["input_shape"]
+    kwargs.pop("classes")
+
+    # Build the model
+    model = arch_fn(**kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of classes is not the same as the number of classes in the pretrained model =>
+        # remove the last layer weights
+        load_pretrained_params(model, default_cfgs[arch]["url"])
+
+    model.cfg = _cfg
+
+    return model
+
+
+def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_tiny
+    >>> model = textnetfast_tiny(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_tiny",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
+        ],
+        **kwargs,
+    )
+
+
+def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_small
+    >>> model = textnetfast_small(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_small",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+        ],
+        **kwargs,
+    )
+
+
+def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
+    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import textnetfast_base
+    >>> model = textnetfast_base(pretrained=False)
+    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained: boolean, True if model is pretrained
+
+    Returns:
+        A TextNetFast model
+    """
+
+    return _textnetfast(
+        "textnetfast_base",
+        pretrained,
+        TextNetFast,
+        stage1=[
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage2=[
+            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
+        ],
+        stage3=[
+            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
+            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
+        ],
+        stage4=[
+            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
+            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
+        ],
+        **kwargs,
+    )
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index 364414462c..516b2b3866 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -3,6 +3,8 @@
 import tensorflow as tf
 from tensorflow.keras import layers
 
+__all__ = ["RepConvLayer"]
+
 
 class RepConvLayer(layers.Layer):
     def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1):
@@ -33,11 +35,11 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
         if kernel_size[1] != 1:
             self.ver_conv = tf.keras.Sequential(
                 [
-                    layers.ZeroPadding2D(padding=padding),
+                    layers.ZeroPadding2D(padding=(int(((kernel_size[0] - 1) * dilation) / 2), 0)),
                     layers.Conv2D(
                         filters=out_channels,
                         kernel_size=(kernel_size[0], 1),
-                        strides=(stride, 1),
+                        strides=stride,
                         dilation_rate=(dilation, 1),
                         groups=groups,
                         use_bias=False,
@@ -53,7 +55,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
         if kernel_size[0] != 1:
             self.hor_conv = tf.keras.Sequential(
                 [
-                    layers.ZeroPadding2D(padding=padding),
+                    layers.ZeroPadding2D(padding=(0, int(((kernel_size[1] - 1) * dilation) / 2))),
                     layers.Conv2D(
                         filters=out_channels,
                         kernel_size=(1, kernel_size[1]),
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
index 8490c09f11..3f9befe433 100644
--- a/doctr/models/utils/tensorflow.py
+++ b/doctr/models/utils/tensorflow.py
@@ -8,6 +8,7 @@
 from typing import Any, Callable, List, Optional, Tuple, Union
 from zipfile import ZipFile
 
+import numpy as np
 import tensorflow as tf
 import tf2onnx
 from tensorflow.keras import Model, layers
@@ -70,8 +71,7 @@ def conv_sequence(
 ) -> List[layers.Layer]:
     """Builds a convolutional-based layer sequence
 
-    >>> from tensorflow.keras import Sequential
-    >>> from doctr.models import conv_sequence
+    >>> from tensorflow.keras import Sequential
+    >>> from doctr.models.utils import conv_sequence
     >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3]))
 
     Args:
@@ -160,3 +160,110 @@ def export_model_to_onnx(
 
     logging.info(f"Model exported to {model_name}.zip")
     return f"{model_name}.onnx", output
+
+
+def rep_model_convert(model):
+    for layer in model.layers:
+        if hasattr(layer, "switch_to_test"):
+            layer.switch_to_test()
+    return model
+
+
+def rep_model_unconvert(model):
+    for layer in model.layers:
+        if hasattr(layer, "switch_to_train"):
+            layer.switch_to_train()
+    return model
+
+
+def rep_model_convert_deploy(model):
+    for layer in model.layers:
+        if hasattr(layer, "switch_to_deploy"):
+            layer.switch_to_deploy()
+    return model
+
+
+def fuse_conv_bn(conv, bn):
+    """During inference, the functionality of batch norm layers is turned off but
+    only the mean and variance along channels are used, which exposes the opportunity
+    to fuse it with the preceding conv layers to save computations and simplify
+    network structures."""
+    print(dir(conv))
+    conv_weights, conv_biases = conv.get_weights()
+    bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()
+
+    if conv_biases is None:
+        conv_biases = np.zeros_like(bn_running_mean)
+
+    epsilon = bn.epsilon
+    scale_factor = bn_weights / np.sqrt(bn_running_var + epsilon)
+
+    # Reshape the scale factor to match the convolutional weights shape
+    scale_factor = scale_factor.reshape((1, 1, 1, -1))
+
+    # Update convolutional weights and biases
+    fused_conv_weights = conv_weights * scale_factor
+    fused_conv_biases = (conv_biases - bn_running_mean) * scale_factor.flatten() + bn_biases
+
+    # Setting the updated weights and biases in conv layer
+    conv.set_weights([fused_conv_weights, fused_conv_biases])
+    conv.old_weight, conv.old_bias = conv.get_weights()
+    return conv
+
+
+def fuse_module(model):
+    last_conv = None
+
+    for layer in model.layers:
+        if isinstance(layer, (tf.keras.layers.BatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            # Fused Conv and BN (You would need to define fuse_conv_bn_tf)
+            fuse_conv_bn(last_conv, layer)
+            # Here you'd need to replace the last_conv layer with fused_conv
+            # in the model, and replace the current layer with an identity layer.
+            # This is non-trivial in TensorFlow as Keras models are not as
+            # dynamically modifiable as PyTorch models.
+
+        elif isinstance(layer, tf.keras.layers.Conv2D):
+            last_conv = layer
+        else:
+            # Recursively apply to nested models
+            fuse_module(layer)
+    return model
+
+
+def unfuse_conv_bn(conv, bn):
+    """During inference, the functionary of batch norm layers is turned off but
+    only the mean and var alone channels are used, which exposes the chance to
+    fuse it with the preceding conv layers to save computations and simplify
+    network structures."""
+    conv.set_weights([conv.old_weight, conv.old_bias])
+    return conv
+
+
+def unfuse_module(model):
+    last_conv = None
+
+    for layer in model.layers:
+
+        if isinstance(layer, tf.keras.layers.Lambda):
+            if last_conv is None:
+                continue
+            unfused_conv = unfuse_conv_bn(last_conv, layer)
+
+            # In TensorFlow, we can't modify the model in-place like in PyTorch,
+            # so you would need to create a new model with the modified layers.
+            # Here, you'd replace the last_conv layer with unfused_conv and swap
+            # the identity placeholder back to a BatchNormalization layer.
+
+        elif isinstance(layer, tf.keras.layers.Conv2D):
+            last_conv = layer
+        else:
+            # Recursive call for potentially nested layers (e.g., in case of a nested model)
+            unfuse_module(layer)
+    return model
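The algebra in fuse_conv_bn above can be checked end to end on standalone Keras layers; a minimal sketch (layer sizes are illustrative, and the conv is created without a bias, as in RepConvLayer):

    import numpy as np
    import tensorflow as tf

    conv = tf.keras.layers.Conv2D(8, 3, padding="same", use_bias=False)
    bn = tf.keras.layers.BatchNormalization()
    x = tf.random.normal((1, 16, 16, 3))
    y_ref = bn(conv(x), training=False)

    # Fold the BN statistics into the kernel and a new bias term
    kernel = conv.get_weights()[0]
    gamma, beta, mean, var = bn.get_weights()
    scale = gamma / np.sqrt(var + bn.epsilon)
    fused_kernel = kernel * scale.reshape(1, 1, 1, -1)
    fused_bias = beta - mean * scale  # conv has no bias, so the bias term reduces to this

    fused = tf.keras.layers.Conv2D(8, 3, padding="same", use_bias=True)
    fused.build(x.shape)
    fused.set_weights([fused_kernel, fused_bias])
    np.testing.assert_allclose(y_ref.numpy(), fused(x).numpy(), rtol=1e-4, atol=1e-5)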

From a8ac914672fd8a2d9b053206cef9d3d363e04aea Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sat, 9 Sep 2023 14:51:22 +0200
Subject: [PATCH 33/44] starting to solve eval mode of the TextNetFast model

---
 .../classification/textnet_fast/tensorflow.py | 14 ++++-----
 doctr/models/modules/layers/tensorflow.py     | 20 ++++++-------
 doctr/models/utils/tensorflow.py              | 30 +++++++++----------
 3 files changed, 30 insertions(+), 34 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index fb0acecf9c..9528487316 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -66,13 +66,14 @@ def __init__(
         stage4: List[Dict[str, Union[int, List[int]]]],
         include_top: bool = True,
         num_classes: int = 1000,
-        input_shape: Optional[Tuple[int, int, int]] = None,
         cfg: Optional[Dict[str, Any]] = None,
+        input_shape: Optional[Tuple[int, int, int]] = None,
     ) -> None:
-        first_conv = tf.keras.Sequential(
-            conv_sequence(out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2)
-        )
-        _layers = [first_conv]
+        _layers = [
+            tf.keras.Sequential(
+                conv_sequence(out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape),
+            )
+        ]
 
         for stage in [stage1, stage2, stage3, stage4]:
             stage_ = tf.keras.Sequential([RepConvLayer(**params) for params in stage])
@@ -94,19 +95,16 @@ def eval(self, mode=False):
         self = rep_model_convert(self)
         self = fuse_module(self)
         self.trainable = mode
-        return self
 
     def train(self, mode=True):
         self = unfuse_module(self)
         self = rep_model_unconvert(self)
         self.trainable = mode
-        return self
 
     def test(self, mode=False):
         self = rep_model_convert_deploy(self)
         self = fuse_module(self)
         self.trainable = mode
-        return self
 
 
 def _textnetfast(
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index 516b2b3866..959bfe1bfe 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -27,11 +27,10 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                     use_bias=False,
                     input_shape=(None, None, in_channels),
                 ),
+                layers.BatchNormalization()
             ]
         )
 
-        self.main_bn = layers.BatchNormalization()
-
         if kernel_size[1] != 1:
             self.ver_conv = tf.keras.Sequential(
                 [
@@ -45,12 +44,12 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                         use_bias=False,
                         input_shape=(None, None, in_channels),
                     ),
+                    layers.BatchNormalization()
                 ]
             )
 
-            self.ver_bn = layers.BatchNormalization()
         else:
-            self.ver_conv, self.ver_bn = None, None
+            self.ver_conv = None
 
         if kernel_size[0] != 1:
             self.hor_conv = tf.keras.Sequential(
@@ -65,23 +64,24 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                         use_bias=False,
                         input_shape=(None, None, in_channels),
                     ),
+                    layers.BatchNormalization()
                 ]
             )
-
-            self.hor_bn = layers.BatchNormalization()
         else:
-            self.hor_conv, self.hor_bn = None, None
+            self.hor_conv = None
 
         self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+        
+        self.layers = [self.main_conv, self.ver_conv, self.hor_conv, self.rbr_identity, self.activation]
 
     def call(
         self,
         x: tf.Tensor,
         **kwargs: Any,
     ) -> tf.Tensor:
-        main_outputs = self.main_bn(self.main_conv(x, **kwargs), **kwargs)
-        vertical_outputs = self.ver_bn(self.ver_conv(x, **kwargs), **kwargs) if self.ver_conv is not None else 0
-        horizontal_outputs = self.hor_bn(self.hor_conv(x, **kwargs), **kwargs) if self.hor_conv is not None else 0
+        main_outputs = self.main_conv(x, **kwargs)
+        vertical_outputs = self.ver_conv(x, **kwargs) if self.ver_conv is not None else 0
+        horizontal_outputs = self.hor_conv(x, **kwargs) if self.hor_conv is not None else 0
         id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
 
         p = main_outputs + vertical_outputs
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
index 3f9befe433..6c5de77740 100644
--- a/doctr/models/utils/tensorflow.py
+++ b/doctr/models/utils/tensorflow.py
@@ -13,6 +13,7 @@
 import tf2onnx
 from tensorflow.keras import Model, layers
 
+from doctr.models.modules.layers.tensorflow import RepConvLayer
 from doctr.utils.data import download_from_url
 
 logging.getLogger("tensorflow").setLevel(logging.DEBUG)
@@ -188,13 +189,15 @@ def fuse_conv_bn(conv, bn):
     only the mean and variance along channels are used, which exposes the opportunity
     to fuse it with the preceding conv layers to save computations and simplify
     network structures."""
-    print(dir(conv))
-    conv_weights, conv_biases = conv.get_weights()
-    bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()
 
-    if conv_biases is None:
+    
+    bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()
+    weights = conv.get_weights()
+    if len(weights) == 1:
+        conv_weights = weights[0]
         conv_biases = np.zeros_like(bn_running_mean)
-
+    else:
+        conv_weights, conv_biases = conv.get_weights()
     epsilon = bn.epsilon
     scale_factor = bn_weights / np.sqrt(bn_running_var + epsilon)
 
@@ -205,30 +208,25 @@ def fuse_conv_bn(conv, bn):
     fused_conv_weights = conv_weights * scale_factor
     fused_conv_biases = (conv_biases - bn_running_mean) * scale_factor.flatten() + bn_biases
 
-    # Setting the updated weights and biases in conv layer
+    conv.use_bias = True
+    conv.build(input_shape=conv.input_shape)
     conv.set_weights([fused_conv_weights, fused_conv_biases])
-    conv.old_weight, conv.old_bias = conv.get_weights()
-    return conv
+    conv.old_weight, conv.old_bias = conv_weights, conv_biases
 
 
 def fuse_module(model):
     last_conv = None
 
     for layer in model.layers:
+        print(layer)
         if isinstance(layer, (tf.keras.layers.BatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization)):
             if last_conv is None:  # only fuse BN that is after Conv
                 continue
-            # Fused Conv and BN (You would need to define fuse_conv_bn_tf)
+            print('ok')
             fuse_conv_bn(last_conv, layer)
-            # Here you'd need to replace the last_conv layer with fused_conv
-            # in the model, and replace the current layer with an identity layer.
-            # This is non-trivial in TensorFlow as Keras models are not as
-            # dynamically modifiable as PyTorch models.
-
         elif isinstance(layer, tf.keras.layers.Conv2D):
             last_conv = layer
-        else:
-            # Recursively apply to nested models
+        elif isinstance(layer, (tf.keras.Sequential, RepConvLayer)):
             fuse_module(layer)
     return model
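After this patch, RepConvLayer runs up to three conv+BN branches in parallel and sums their outputs; the asymmetric paddings keep every branch output spatially aligned with the main one. A toy sketch of that summation (channel counts are illustrative):

    import tensorflow as tf
    from tensorflow.keras import layers

    def branch(kernel_size, padding):
        return tf.keras.Sequential([
            layers.ZeroPadding2D(padding=padding),
            layers.Conv2D(8, kernel_size, use_bias=False),
            layers.BatchNormalization(),
        ])

    main = branch((3, 3), (1, 1))  # full 3x3 conv
    ver = branch((3, 1), (1, 0))   # vertical 3x1 conv
    hor = branch((1, 3), (0, 1))   # horizontal 1x3 conv

    x = tf.random.normal((1, 32, 32, 8))
    out = tf.nn.relu(main(x) + ver(x) + hor(x))  # every branch yields (1, 32, 32, 8)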
 

From 48fb8e549925e6f162ed9defde1e5f5c5f0b06a3 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sat, 9 Sep 2023 18:14:18 +0200
Subject: [PATCH 34/44] [skip ci] Last changes for switching the TextNetFast
 model between train and eval mode (not working for the moment)

---
 .../classification/textnet_fast/tensorflow.py |  4 +++-
 doctr/models/modules/layers/tensorflow.py     | 16 +++++++--------
 doctr/models/utils/pytorch.py                 |  1 -
 doctr/models/utils/tensorflow.py              | 20 +++++++++++++------
 4 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
index 9528487316..b9e4213408 100644
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ b/doctr/models/classification/textnet_fast/tensorflow.py
@@ -71,7 +71,9 @@ def __init__(
     ) -> None:
         _layers = [
             tf.keras.Sequential(
-                conv_sequence(out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape),
+                conv_sequence(
+                    out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
+                ),
             )
         ]
 
diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py
index 959bfe1bfe..1ce7bdb7c3 100644
--- a/doctr/models/modules/layers/tensorflow.py
+++ b/doctr/models/modules/layers/tensorflow.py
@@ -27,7 +27,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                     use_bias=False,
                     input_shape=(None, None, in_channels),
                 ),
-                layers.BatchNormalization()
+                layers.BatchNormalization(),
             ]
         )
 
@@ -44,7 +44,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                         use_bias=False,
                         input_shape=(None, None, in_channels),
                     ),
-                    layers.BatchNormalization()
+                    layers.BatchNormalization(),
                 ]
             )
 
@@ -64,15 +64,15 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, g
                         use_bias=False,
                         input_shape=(None, None, in_channels),
                     ),
-                    layers.BatchNormalization()
+                    layers.BatchNormalization(),
                 ]
             )
         else:
             self.hor_conv = None
 
-        self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
-        
-        self.layers = [self.main_conv, self.ver_conv, self.hor_conv, self.rbr_identity, self.activation]
+        # self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+
+        self.layers = [i for i in [self.main_conv, self.ver_conv, self.hor_conv, self.activation] if i is not None]
 
     def call(
         self,
@@ -82,10 +82,10 @@ def call(
         main_outputs = self.main_conv(x, **kwargs)
         vertical_outputs = self.ver_conv(x, **kwargs) if self.ver_conv is not None else 0
         horizontal_outputs = self.hor_conv(x, **kwargs) if self.hor_conv is not None else 0
-        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
+        # id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0
 
         p = main_outputs + vertical_outputs
-        q = horizontal_outputs + id_out
+        q = horizontal_outputs  # + id_out
         r = p + q
 
         return self.activation(r)
diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py
index e438498ad4..1e2b59ac1c 100644
--- a/doctr/models/utils/pytorch.py
+++ b/doctr/models/utils/pytorch.py
@@ -200,7 +200,6 @@ def fuse_module(m):
                 continue
             fused_conv = fuse_conv_bn(last_conv, child)
             m._modules[last_conv_name] = fused_conv
-            # To reduce changes, set BN as Identity instead of deleting it.
             m._modules[name] = nn.Identity()
             last_conv = None
         elif isinstance(child, nn.Conv2d):
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
index 6c5de77740..51f36a2243 100644
--- a/doctr/models/utils/tensorflow.py
+++ b/doctr/models/utils/tensorflow.py
@@ -190,7 +190,6 @@ def fuse_conv_bn(conv, bn):
     to fuse it with the preceding conv layers to save computations and simplify
     network structures."""
 
-    
     bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights()
     weights = conv.get_weights()
     if len(weights) == 1:
@@ -212,18 +211,27 @@ def fuse_conv_bn(conv, bn):
     conv.build(input_shape=conv.input_shape)
     conv.set_weights([fused_conv_weights, fused_conv_biases])
     conv.old_weight, conv.old_bias = conv_weights, conv_biases
+    return conv
 
 
 def fuse_module(model):
     last_conv = None
-
-    for layer in model.layers:
-        print(layer)
+    for i, layer in enumerate(model.layers):
         if isinstance(layer, (tf.keras.layers.BatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization)):
             if last_conv is None:  # only fuse BN that is after Conv
                 continue
-            print('ok')
-            fuse_conv_bn(last_conv, layer)
+            fuse_conv = fuse_conv_bn(last_conv, layer)
+            new_layer = tf.keras.layers.Lambda(lambda x: x)
+            model.layers[i] = new_layer
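+            # NOTE: this assignment is a no-op: model.layers returns a copy in
+            # Keras, so the fused graph is not actually rewired here (still WIP)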
+
+            setattr(layer, layer.name, new_layer)
         elif isinstance(layer, tf.keras.layers.Conv2D):
             last_conv = layer
         elif isinstance(layer, (tf.keras.Sequential, RepConvLayer)):
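Since Keras models cannot be rewired in place the way PyTorch modules can, one workaround (a sketch of an alternative, not the approach attempted above) is to rebuild the model with clone_model and map every already-fused BatchNormalization to a pass-through:

    import tensorflow as tf

    def drop_bn(layer):
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            return tf.keras.layers.Lambda(lambda x: x)
        return layer.__class__.from_config(layer.get_config())

    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
    ])
    # clone_model does not copy weights: they would have to be transferred
    # layer by layer afterwards (fused weights into the conv, omitted here)
    inference_model = tf.keras.models.clone_model(model, clone_function=drop_bn)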

From 77d78b6d51d55c4e2ae0a8d78da37ed6f1f54194 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sun, 10 Sep 2023 15:11:43 +0200
Subject: [PATCH 35/44] [skip ci] deleting TensorFlow TextNetFast code for
 further integration

---
 .../classification/textnet_fast/__init__.py   |   2 +-
 .../classification/textnet_fast/tensorflow.py | 309 ------------------
 2 files changed, 1 insertion(+), 310 deletions(-)
 delete mode 100644 doctr/models/classification/textnet_fast/tensorflow.py

diff --git a/doctr/models/classification/textnet_fast/__init__.py b/doctr/models/classification/textnet_fast/__init__.py
index c7110f5669..64556e403a 100644
--- a/doctr/models/classification/textnet_fast/__init__.py
+++ b/doctr/models/classification/textnet_fast/__init__.py
@@ -3,4 +3,4 @@
 if is_tf_available():
     from .tensorflow import *
 elif is_torch_available():
-    from .pytorch import *  # type: ignore[assignment]
+    from .pytorch import *
diff --git a/doctr/models/classification/textnet_fast/tensorflow.py b/doctr/models/classification/textnet_fast/tensorflow.py
deleted file mode 100644
index b9e4213408..0000000000
--- a/doctr/models/classification/textnet_fast/tensorflow.py
+++ /dev/null
@@ -1,309 +0,0 @@
-# Copyright (C) 2021-2023, Mindee.
-
-# This program is licensed under the Apache License 2.0.
-# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-
-
-from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import tensorflow as tf
-
-from doctr.datasets import VOCABS
-from doctr.models.modules.layers.tensorflow import RepConvLayer
-from doctr.models.utils.tensorflow import (
-    conv_sequence,
-    fuse_module,
-    rep_model_convert,
-    rep_model_convert_deploy,
-    rep_model_unconvert,
-    unfuse_module,
-)
-
-from ...utils import load_pretrained_params
-
-__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]
-
-default_cfgs: Dict[str, Dict[str, Any]] = {
-    "textnetfast_tiny": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_small": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-    "textnetfast_base": {
-        "input_shape": (3, 32, 32),
-        "classes": list(VOCABS["french"]),
-        "url": None,
-    },
-}
-
-
-class TextNetFast(tf.keras.Sequential):
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    Args:
-        stage1 (Dict[str, Union[int, List[int]]]): Configuration for stage 1
-        stage2 (Dict[str, Union[int, List[int]]]): Configuration for stage 2
-        stage3 (Dict[str, Union[int, List[int]]]): Configuration for stage 3
-        stage4 (Dict[str, Union[int, List[int]]]): Configuration for stage 4
-        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
-        num_classes (int, optional): Number of output classes. Defaults to 1000.
-        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
-    """
-
-    def __init__(
-        self,
-        stage1: List[Dict[str, Union[int, List[int]]]],
-        stage2: List[Dict[str, Union[int, List[int]]]],
-        stage3: List[Dict[str, Union[int, List[int]]]],
-        stage4: List[Dict[str, Union[int, List[int]]]],
-        include_top: bool = True,
-        num_classes: int = 1000,
-        cfg: Optional[Dict[str, Any]] = None,
-        input_shape: Optional[Tuple[int, int, int]] = None,
-    ) -> None:
-        _layers = [
-            tf.keras.Sequential(
-                conv_sequence(
-                    out_channels=64, activation="relu", bn=True, kernel_size=3, strides=2, input_shape=input_shape
-                ),
-            )
-        ]
-
-        for stage in [stage1, stage2, stage3, stage4]:
-            stage_ = tf.keras.Sequential([RepConvLayer(**params) for params in stage])
-            _layers.extend([stage_])
-
-        if include_top:
-            classif_block = [
-                tf.keras.layers.GlobalAveragePooling2D(),
-                tf.keras.layers.Flatten(),
-                tf.keras.layers.Dense(num_classes, activation="softmax", use_bias=True, kernel_initializer="he_normal"),
-            ]
-            classif_block_ = tf.keras.Sequential(classif_block)
-            _layers.append(classif_block_)
-
-        super().__init__(_layers)
-        self.cfg = cfg
-
-    def eval(self, mode=False):
-        self = rep_model_convert(self)
-        self = fuse_module(self)
-        self.trainable = mode
-
-    def train(self, mode=True):
-        self = unfuse_module(self)
-        self = rep_model_unconvert(self)
-        self.trainable = mode
-
-    def test(self, mode=False):
-        self = rep_model_convert_deploy(self)
-        self = fuse_module(self)
-        self.trainable = mode
-
-
-def _textnetfast(
-    arch: str,
-    pretrained: bool,
-    arch_fn,
-    **kwargs: Any,
-) -> TextNetFast:
-    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
-    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
-    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])
-
-    _cfg = deepcopy(default_cfgs[arch])
-    _cfg["num_classes"] = kwargs["num_classes"]
-    _cfg["classes"] = kwargs["classes"]
-    _cfg["input_shape"] = kwargs["input_shape"]
-    kwargs.pop("classes")
-
-    # Build the model
-    model = arch_fn(**kwargs)
-    # Load pretrained parameters
-    if pretrained:
-        # The number of classes is not the same as the number of classes in the pretrained model =>
-        # remove the last layer weights
-        load_pretrained_params(model, default_cfgs[arch]["url"])
-
-    return model
-
-
-def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_tiny
-    >>> model = textnetfast_tiny(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNet model
-    """
-
-    return _textnetfast(
-        "textnetfast_tiny",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
-        ],
-        **kwargs,
-    )
-
-
-def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_small
-    >>> model = textnetfast_small(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_small",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-        ],
-        **kwargs,
-    )
-
-
-def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
-    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
-    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
-    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
-
-    >>> import torch
-    >>> from doctr.models import textnetfast_base
-    >>> model = textnetfast_base(pretrained=False)
-    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
-    >>> out = model(input_tensor)
-
-    Args:
-        pretrained: boolean, True if model is pretrained
-
-    Returns:
-        A TextNetFast model
-    """
-
-    return _textnetfast(
-        "textnetfast_base",
-        pretrained,
-        TextNetFast,
-        stage1=[
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage2=[
-            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
-        ],
-        stage3=[
-            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
-            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
-        ],
-        stage4=[
-            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
-            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
-        ],
-        **kwargs,
-    )
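Each stage in the deleted file is just a list of RepConvLayer keyword dicts; on the PyTorch side, which is kept, an equivalent stage can be assembled in one line. A sketch, assuming the PyTorch RepConvLayer accepts the same keywords as the configs above:

    import torch
    from torch import nn
    from doctr.models.modules.layers.pytorch import RepConvLayer

    stage2 = [
        {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
        {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
    ]
    stage = nn.Sequential(*[RepConvLayer(**params) for params in stage2])
    out = stage(torch.rand(1, 64, 64, 64))  # expected (1, 128, 32, 32): stride 2 halves the resolution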

From d7b9ca2bc4902254ff6e72dc4e3eaafa1f9edaa4 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sun, 10 Sep 2023 16:01:54 +0200
Subject: [PATCH 36/44] [skip ci] first commit adding neck, head and FAST
 model in torch

---
 doctr/models/detection/fast/__init__.py |   6 +
 doctr/models/detection/fast/base.py     | 249 ++++++++++++++++++++++++
 doctr/models/detection/fast/pytorch.py  | 235 ++++++++++++++++++++++
 3 files changed, 490 insertions(+)
 create mode 100644 doctr/models/detection/fast/__init__.py
 create mode 100644 doctr/models/detection/fast/base.py
 create mode 100644 doctr/models/detection/fast/pytorch.py

diff --git a/doctr/models/detection/fast/__init__.py b/doctr/models/detection/fast/__init__.py
new file mode 100644
index 0000000000..c7110f5669
--- /dev/null
+++ b/doctr/models/detection/fast/__init__.py
@@ -0,0 +1,6 @@
+from doctr.file_utils import is_tf_available, is_torch_available
+
+if is_tf_available():
+    from .tensorflow import *
+elif is_torch_available():
+    from .pytorch import *  # type: ignore[assignment]
diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py
new file mode 100644
index 0000000000..02c3936774
--- /dev/null
+++ b/doctr/models/detection/fast/base.py
@@ -0,0 +1,249 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from typing import Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+from doctr.models.core import BaseModel
+
+from ..core import DetectionPostProcessor
+
+__all__ = ["_LinkNet", "LinkNetPostProcessor"]
+
+
+class LinkNetPostProcessor(DetectionPostProcessor):
+    """Implements a post processor for LinkNet model.
+
+    Args:
+        bin_thresh: threshold used to binzarized p_map at inference time
+        box_thresh: minimal objectness score to consider a box
+        assume_straight_pages: whether the inputs were expected to have horizontal text elements
+    """
+
+    def __init__(
+        self,
+        bin_thresh: float = 0.1,
+        box_thresh: float = 0.1,
+        assume_straight_pages: bool = True,
+    ) -> None:
+        super().__init__(box_thresh, bin_thresh, assume_straight_pages)
+        self.unclip_ratio = 1.2
+
+    def polygon_to_box(
+        self,
+        points: np.ndarray,
+    ) -> np.ndarray:
+        """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon
+
+        Args:
+            points: The first parameter.
+
+        Returns:
+            a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle)
+        """
+        if not self.assume_straight_pages:
+            # Compute the rectangle polygon enclosing the raw polygon
+            rect = cv2.minAreaRect(points)
+            points = cv2.boxPoints(rect)
+            # Add 1 pixel to correct cv2 approx
+            area = (rect[1][0] + 1) * (1 + rect[1][1])
+            length = 2 * (rect[1][0] + rect[1][1]) + 2
+        else:
+            poly = Polygon(points)
+            area = poly.area
+            length = poly.length
+        distance = area * self.unclip_ratio / length  # compute distance to expand polygon
+        offset = pyclipper.PyclipperOffset()
+        offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+        _points = offset.Execute(distance)
+        # Take biggest stack of points
+        idx = 0
+        if len(_points) > 1:
+            max_size = 0
+            for _idx, p in enumerate(_points):
+                if len(p) > max_size:
+                    idx = _idx
+                    max_size = len(p)
+            # We ensure that _points can be correctly casted to a ndarray
+            _points = [_points[idx]]
+        expanded_points: np.ndarray = np.asarray(_points)  # expand polygon
+        if len(expanded_points) < 1:
+            return None  # type: ignore[return-value]
+        return (
+            cv2.boundingRect(expanded_points)
+            if self.assume_straight_pages
+            else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0)
+        )
+
+    def bitmap_to_boxes(
+        self,
+        pred: np.ndarray,
+        bitmap: np.ndarray,
+    ) -> np.ndarray:
+        """Compute boxes from a bitmap/pred_map: find connected components then filter boxes
+
+        Args:
+            pred: Pred map from differentiable linknet output
+            bitmap: Bitmap map computed from pred (binarized)
+
+        Returns:
+            numpy array of boxes for the bitmap: (N, 5) rows of
+                (xmin, ymin, xmax, ymax, score) for straight pages,
+                (N, 4, 2) quadrangles otherwise
+        """
+        height, width = bitmap.shape[:2]
+        boxes: List[Union[np.ndarray, List[float]]] = []
+        # get contours from connected components on the bitmap
+        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+        for contour in contours:
+            # Check whether smallest enclosing bounding box is not too small
+            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):
+                continue
+            # Compute objectness
+            if self.assume_straight_pages:
+                x, y, w, h = cv2.boundingRect(contour)
+                points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]])
+                score = self.box_score(pred, points, assume_straight_pages=True)
+            else:
+                score = self.box_score(pred, contour, assume_straight_pages=False)
+
+            if score < self.box_thresh:  # remove polygons with a weak objectness
+                continue
+
+            if self.assume_straight_pages:
+                _box = self.polygon_to_box(points)
+            else:
+                _box = self.polygon_to_box(np.squeeze(contour))
+
+            if self.assume_straight_pages:
+                # compute relative polygon to get rid of img shape
+                x, y, w, h = _box
+                xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height
+                boxes.append([xmin, ymin, xmax, ymax, score])
+            else:
+                # compute relative box to get rid of img shape
+                _box[:, 0] /= width
+                _box[:, 1] /= height
+                boxes.append(_box)
+
+        if not self.assume_straight_pages:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
+        else:
+            return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
+
+
+class _LinkNet(BaseModel):
+    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
+    <https://arxiv.org/pdf/1707.03718.pdf>`_.
+
+    Args:
+        out_chan: number of channels for the output
+    """
+
+    min_size_box: int = 3
+    assume_straight_pages: bool = True
+    shrink_ratio = 0.5
+
+    def build_target(
+        self,
+        target: List[Dict[str, np.ndarray]],
+        output_shape: Tuple[int, int, int],
+        channels_last: bool = True,
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """Build the target, and it's mask to be used from loss computation.
+
+        Args:
+            target: target coming from dataset
+            output_shape: shape of the output of the model without batch_size
+            channels_last: whether channels are last or not
+
+        Returns:
+            the new formatted target and the mask
+        """
+
+        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
+            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
+        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
+            raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+        h: int
+        w: int
+        if channels_last:
+            h, w, num_classes = output_shape
+        else:
+            num_classes, h, w = output_shape
+        target_shape = (len(target), num_classes, h, w)
+
+        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
+        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
+
+        for idx, tgt in enumerate(target):
+            for class_idx, _tgt in enumerate(tgt.values()):
+                # Draw each polygon on gt
+                if _tgt.shape[0] == 0:
+                    # Empty image, full masked
+                    seg_mask[idx, class_idx] = False
+
+                # Absolute bounding boxes
+                abs_boxes = _tgt.copy()
+
+                if abs_boxes.ndim == 3:
+                    abs_boxes[:, :, 0] *= w
+                    abs_boxes[:, :, 1] *= h
+                    polys = abs_boxes
+                    boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
+                    abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
+                else:
+                    abs_boxes[:, [0, 2]] *= w
+                    abs_boxes[:, [1, 3]] *= h
+                    abs_boxes = abs_boxes.round().astype(np.int32)
+                    polys = np.stack(
+                        [
+                            abs_boxes[:, [0, 1]],
+                            abs_boxes[:, [0, 3]],
+                            abs_boxes[:, [2, 3]],
+                            abs_boxes[:, [2, 1]],
+                        ],
+                        axis=1,
+                    )
+                    boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
+
+                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
+                    # Mask boxes that are too small
+                    if box_size < self.min_size_box:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+
+                    # Negative shrink for gt, as described in paper
+                    polygon = Polygon(poly)
+                    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
+                    subject = [tuple(coor) for coor in poly]
+                    padding = pyclipper.PyclipperOffset()
+                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
+                    shrunken = padding.Execute(-distance)
+
+                    # Draw polygon on gt if it is valid
+                    if len(shrunken) == 0:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
+                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
+                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
+                        continue
+                    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1)
+
+        # Don't forget to switch back to channel last if Tensorflow is used
+        if channels_last:
+            seg_target = seg_target.transpose((0, 2, 3, 1))
+            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+
+        return seg_target, seg_mask
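The negative shrink applied to each ground-truth polygon above can be exercised in isolation; a minimal sketch (the square below is illustrative):

    import numpy as np
    import pyclipper
    from shapely.geometry import Polygon

    poly = np.array([[0, 0], [100, 0], [100, 100], [0, 100]])
    shrink_ratio = 0.5
    polygon = Polygon(poly)
    distance = polygon.area * (1 - shrink_ratio**2) / polygon.length  # 18.75 here
    offset = pyclipper.PyclipperOffset()
    offset.AddPath([tuple(p) for p in poly], pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrunken = np.array(offset.Execute(-distance)[0])  # roughly the square (19, 19)-(81, 81)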
diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
new file mode 100644
index 0000000000..9788742ef2
--- /dev/null
+++ b/doctr/models/detection/fast/pytorch.py
@@ -0,0 +1,235 @@
+# Copyright (C) 2021-2023, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torch.nn as nn
+
+from doctr.file_utils import CLASS_NAME
+from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny, textnetfast_small, textnetfast_base
+from ...utils import load_pretrained_params
+from .base import LinkNetPostProcessor, _LinkNet
+from models.loss import build_loss
+
+__all__ = ["fast_tiny", "fast_small", "fast_base"]
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    "fast_tiny": {
+        "input_shape": (3, 1024, 1024),
+        "url": None,
+    },
+    "fast_small": {
+        "input_shape": (3, 1024, 1024),
+        "url": None,
+    },
+    "fast_base": {
+        "input_shape": (3, 1024, 1024),
+        "url": None,
+    },
+}
+# reimplement FAST and DETECTION_HEAD
+# NECK AND TEXTNET READY
+
+class FAST(nn.Module):
+    def __init__(self, backbone, neck, detection_head):
+        super(FAST, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        self.det_head = detection_head
+        
+    def _upsample(self, x, size, scale=1):
+        _, _, H, W = size
+        return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear')
+
+    def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
+                gt_instances=None, img_metas=None, cfg=None):
+        outputs = dict()
+
+        if not self.training:
+            torch.cuda.synchronize()
+            start = time.time()
+
+        # backbone
+        f = self.backbone(imgs)
+
+        if not self.training:
+            torch.cuda.synchronize()
+            outputs.update(dict(
+                backbone_time=time.time() - start
+            ))
+            start = time.time()
+
+        # reduce channel
+        f = self.neck(f)
+        
+        if not self.training:
+            torch.cuda.synchronize()
+            outputs.update(dict(
+                neck_time=time.time() - start
+            ))
+            start = time.time()
+
+        # detection
+        det_out = self.det_head(f)
+
+        if not self.training:
+            torch.cuda.synchronize()
+            outputs.update(dict(
+                det_head_time=time.time() - start
+            ))
+
+        if self.training:
+            det_out = self._upsample(det_out, imgs.size(), scale=1)
+            det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances)
+            outputs.update(det_loss)
+        else:
+            det_out = self._upsample(det_out, imgs.size(), scale=4)
+            det_res = self.det_head.get_results(det_out, img_metas, cfg, scale=2)
+            outputs.update(det_res)
+
+        return outputs
+
+
+class FASTHead(nn.Module):
+    def __init__(self, conv, blocks, final, pooling_size,
+                 loss_text, loss_kernel, loss_emb, dropout_ratio=0):
+        super(FASTHead, self).__init__()
+        self.conv = RepConvLayer(in_channels=512, out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+        if blocks is not None:
+            self.blocks = nn.ModuleList(blocks)
+        else:
+            self.blocks = None
+        self.final = ConvLayer(kernel_size=1, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False, in_channels=128, out_channels=5, use_bn=False, act_func=None,
+                               dropout_rate=0, ops_order="weight")
+
+        self.pooling_size = pooling_size
+
+        if dropout_ratio > 0:
+            self.dropout = nn.Dropout2d(dropout_ratio)
+        else:
+            self.dropout = None
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def forward(self, x):
+        x = self.conv(x)
+        if self.blocks is not None:
+            for block in self.blocks:
+                x = block(x)
+        if self.dropout is not None:
+            x = self.dropout(x)
+        x = self.final(x)
+        return x
+
+
+class FASTNeck(nn.Module):
+    def __init__(self, reduce_layers=[64, 128, 256, 512]):
+        super(FASTNeck, self).__init__()
+        
+        self.reduce_layer1 = RepConvLayer(in_channels=reduce_layers[0], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+        self.reduce_layer2 = RepConvLayer(in_channels=reduce_layers[1], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+        self.reduce_layer3 = RepConvLayer(in_channels=reduce_layers[2], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+        self.reduce_layer4 = RepConvLayer(in_channels=reduce_layers[3], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+    
+    def _upsample(self, x, y):
+        _, _, H, W = y.size()
+        return F.interpolate(x, size=(H, W), mode='bilinear')
+    
+    def forward(self, x):
+        f1, f2, f3, f4 = x
+        f1 = self.reduce_layer1(f1)
+        f2 = self.reduce_layer2(f2)
+        f3 = self.reduce_layer3(f3)
+        f4 = self.reduce_layer4(f4)
+
+        f2 = self._upsample(f2, f1)
+        f3 = self._upsample(f3, f1)
+        f4 = self._upsample(f4, f1)
+        f = torch.cat((f1, f2, f3, f4), 1)
+        return f
+
+
+
+def _fast(
+    arch: str,
+    pretrained: bool,
+    backbone_fn: Callable[[bool], nn.Module],
+    pretrained_backbone: bool = True,
+    ignore_keys: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> FAST:
+    pretrained_backbone = pretrained_backbone and not pretrained
+    
+    feat_extractor = backbone_fn(pretrained_backbone)
+    neck = FASTNeck()
+    head = FASTHead()
+    
+    if not kwargs.get("class_names", None):
+        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
+    else:
+        kwargs["class_names"] = sorted(kwargs["class_names"])
+
+    # Build the model
+    model = FAST(feat_extractor, neck, head, cfg=default_cfgs[arch], **kwargs)
+    # Load pretrained parameters
+    if pretrained:
+        # The number of class_names is not the same as the number of classes in the pretrained model =>
+        # remove the layer weights
+        _ignore_keys = (
+            ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
+        )
+        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+
+    return model
+
+
+def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
+    """FAST architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official Pytorch implementation: <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import fast_tiny
+    >>> model = fast_tiny(pretrained=True).eval()
+    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+
+    Returns:
+        text detection architecture
+    """
+    return _fast(
+        "fast_tiny",
+        pretrained,
+        textnetfast_tiny,
+        # change ignore keys
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )
+# set fast_small and fast_base

From 2e1ea2f9db5bd5bbe29a0132bacbc92da4e71022 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Sun, 10 Sep 2023 16:07:37 +0200
Subject: [PATCH 37/44] [skip ci] backbone+neck+head of FAST in torch ready;
 FAST class remains to be reimplemented

---
 doctr/models/detection/fast/pytorch.py | 27 ++++++--------------------
 1 file changed, 6 insertions(+), 21 deletions(-)

diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index 9788742ef2..eb87d81f6b 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -13,9 +13,9 @@
 
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny, textnetfast_small, textnetfast_base
+
 from ...utils import load_pretrained_params
-from .base import LinkNetPostProcessor, _LinkNet
-from models.loss import build_loss
+from .base import FastPostProcessor
 
 __all__ = ["fast_tiny", "fast_small", "fast_base"]
 
@@ -34,8 +34,8 @@
         "url": None,
     },
 }
-# reimplement FAST and DETECTION_HEAD
-# NECK AND TEXTNET READY
+# reimplement FAST Class
+# NECK AND TEXTNET AND HEAD READY
 
 class FAST(nn.Module):
     def __init__(self, backbone, neck, detection_head):
@@ -98,24 +98,15 @@ def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
 
 
 class FASTHead(nn.Module):
-    def __init__(self, conv, blocks, final, pooling_size,
-                 loss_text, loss_kernel, loss_emb, dropout_ratio=0):
+    def __init__(self, conv, final, pooling_size):
         super(FASTHead, self).__init__()
         self.conv = RepConvLayer(in_channels=512, out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
-        if blocks is not None:
-            self.blocks = nn.ModuleList(blocks)
-        else:
-            self.blocks = None
+
         self.final = ConvLayer(kernel_size=1, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False, in_channels=128, out_channels=5, use_bn=False, act_func=None,
                                dropout_rate=0, ops_order="weight")
 
         self.pooling_size = pooling_size
 
-        if dropout_ratio > 0:
-            self.dropout = nn.Dropout2d(dropout_ratio)
-        else:
-            self.dropout = None
-
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight)
@@ -125,11 +116,6 @@ def __init__(self, conv, blocks, final, pooling_size,
 
     def forward(self, x):
         x = self.conv(x)
-        if self.blocks is not None:
-            for block in self.blocks:
-                x = block(x)
-        if self.dropout is not None:
-            x = self.dropout(x)
         x = self.final(x)
         return x
 
@@ -170,7 +156,6 @@ def forward(self, x):
         return f
 
 
-
 def _fast(
     arch: str,
     pretrained: bool,
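
In this revision FASTHead is reduced to a RepConvLayer that maps the
512-channel fused neck features down to 128, followed by a 1x1 ConvLayer
projecting to 5 output channels; in the FAST design, channel 0 carries the
shrunken text-kernel map and the remaining 4 channels carry the pixel
embeddings used by the auxiliary loss. A quick shape-check sketch, with
plain nn.Conv2d stand-ins for RepConvLayer/ConvLayer (an assumption for
brevity; the real layers add reparameterized branches and BN):

    import torch
    import torch.nn as nn

    head = nn.Sequential(
        nn.Conv2d(512, 128, kernel_size=3, padding=1),  # ~ RepConvLayer(512 -> 128)
        nn.Conv2d(128, 5, kernel_size=1, bias=False),   # ~ final ConvLayer(128 -> 5)
    )
    fused = torch.rand(1, 512, 64, 64)  # neck output: 4 feature maps x 128 channels
    out = head(fused)
    print(out.shape)  # torch.Size([1, 5, 64, 64])
    # channel 0 -> text/kernel map, channels 1:5 -> pixel embeddings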

From e83b2acc3565516546b65355acc917d626e35fde Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 11 Sep 2023 09:50:23 +0200
Subject: [PATCH 38/44] [skip ci] fix style and quality issues in the FAST
 torch model

---
 doctr/models/detection/fast/pytorch.py | 160 ++++++++++++++++++-------
 doctr/models/modules/layers/pytorch.py | 150 +++++++++++++++++++++++
 2 files changed, 267 insertions(+), 43 deletions(-)

diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index eb87d81f6b..ff2bf50b13 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -3,23 +3,24 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-from typing import Any, Callable, Dict, List, Optional, Tuple
+import time
+from typing import Any, Callable, Dict, List, Optional
 
-import numpy as np
 import torch
-from torch import nn
-from torch.nn import functional as F
 import torch.nn as nn
+from torch.nn import functional as F
 
 from doctr.file_utils import CLASS_NAME
-from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny, textnetfast_small, textnetfast_base
+from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny
+from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer
 
 from ...utils import load_pretrained_params
-from .base import FastPostProcessor
 
 __all__ = ["fast_tiny", "fast_small", "fast_base"]
 
 
+# modify the ignore_keys in fast_tiny, fast_small, fast_base
+
 default_cfgs: Dict[str, Dict[str, Any]] = {
     "fast_tiny": {
         "input_shape": (3, 1024, 1024),
@@ -37,19 +38,21 @@
 # reimplement FAST Class
 # NECK AND TEXTNET AND HEAD READY
 
+
 class FAST(nn.Module):
     def __init__(self, backbone, neck, detection_head):
         super(FAST, self).__init__()
-        self.backbone = backbone # okay
-        self.neck = neck #okay
+        self.backbone = backbone  # okay
+        self.neck = neck  # okay
         self.det_head = detection_head
-        
+
     def _upsample(self, x, size, scale=1):
         _, _, H, W = size
-        return F.interpolate(x, size=(H // scale, W // scale), mode='bilinear')
+        return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear")
 
-    def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
-                gt_instances=None, img_metas=None, cfg=None):
+    def forward(
+        self, imgs, gt_texts=None, gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None, cfg=None
+    ):
         outputs = dict()
 
         if not self.training:
@@ -61,19 +64,15 @@ def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
 
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(
-                backbone_time=time.time() - start
-            ))
+            outputs.update(dict(backbone_time=time.time() - start))
             start = time.time()
 
         # reduce channel
         f = self.neck(f)
-        
+
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(
-                neck_time=time.time() - start
-            ))
+            outputs.update(dict(neck_time=time.time() - start))
             start = time.time()
 
         # detection
@@ -81,9 +80,7 @@ def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
 
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(
-                det_head_time=time.time() - start
-            ))
+            outputs.update(dict(det_head_time=time.time() - start))
 
         if self.training:
             det_out = self._upsample(det_out, imgs.size(), scale=1)
@@ -98,14 +95,24 @@ def forward(self, imgs, gt_texts=None, gt_kernels=None, training_masks=None,
 
 
 class FASTHead(nn.Module):
-    def __init__(self, conv, final, pooling_size):
+    def __init__(self):
         super(FASTHead, self).__init__()
         self.conv = RepConvLayer(in_channels=512, out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
 
-        self.final = ConvLayer(kernel_size=1, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False, in_channels=128, out_channels=5, use_bn=False, act_func=None,
-                               dropout_rate=0, ops_order="weight")
-
-        self.pooling_size = pooling_size
+        self.final = ConvLayer(
+            kernel_size=1,
+            stride=1,
+            dilation=1,
+            groups=1,
+            bias=False,
+            has_shuffle=False,
+            in_channels=128,
+            out_channels=5,
+            use_bn=False,
+            act_func=None,
+            dropout_rate=0,
+            ops_order="weight",
+        )
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -121,13 +128,21 @@ def forward(self, x):
 
 
 class FASTNeck(nn.Module):
-    def __init__(self, reduce_layers = [64, 128, 256, 512]):
+    def __init__(self, reduce_layers=[64, 128, 256, 512]):
         super(FASTNeck, self).__init__()
-        
-        self.reduce_layer1 = RepConvLayer(in_channels=reduce_layers[0], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
-        self.reduce_layer2 = RepConvLayer(in_channels=reduce_layers[1], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
-        self.reduce_layer3 = RepConvLayer(in_channels=reduce_layers[2], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
-        self.reduce_layer4 = RepConvLayer(in_channels=reduce_layers[3], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1)
+
+        self.reduce_layer1 = RepConvLayer(
+            in_channels=reduce_layers[0], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1
+        )
+        self.reduce_layer2 = RepConvLayer(
+            in_channels=reduce_layers[1], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1
+        )
+        self.reduce_layer3 = RepConvLayer(
+            in_channels=reduce_layers[2], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1
+        )
+        self.reduce_layer4 = RepConvLayer(
+            in_channels=reduce_layers[3], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1
+        )
         self._initialize_weights()
 
     def _initialize_weights(self):
@@ -137,11 +152,11 @@ def _initialize_weights(self):
             elif isinstance(m, nn.BatchNorm2d):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
-    
+
     def _upsample(self, x, y):
         _, _, H, W = y.size()
-        return F.upsample(x, size=(H, W), mode='bilinear')
-    
+        return F.upsample(x, size=(H, W), mode="bilinear")
+
     def forward(self, x):
         f1, f2, f3, f4 = x
         f1 = self.reduce_layer1(f1)
@@ -163,20 +178,20 @@ def _fast(
     pretrained_backbone: bool = True,
     ignore_keys: Optional[List[str]] = None,
     **kwargs: Any,
-) -> Fast:
+) -> FAST:
     pretrained_backbone = pretrained_backbone and not pretrained
-    
-    feat_extractor =  backbone_fn(pretrained_backbone)
+
+    feat_extractor = backbone_fn(pretrained_backbone)
     neck = FASTNeck()
     head = FASTHead()
-    
+
     if not kwargs.get("class_names", None):
         kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
     else:
         kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Build the model
-    model = Fast(feat_extractor, neck, head, cfg=default_cfgs[arch], **kwargs)
+    model = FAST(feat_extractor, neck, head, cfg=default_cfgs[arch], **kwargs)
     # Load pretrained parameters
     if pretrained:
         # The number of class_names is not the same as the number of classes in the pretrained model =>
@@ -189,7 +204,7 @@ def _fast(
     return model
 
 
-def fast_tiny(pretrained: bool = False, **kwargs: Any) -> Fast:
+def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
     """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
     Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
@@ -217,4 +232,63 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> Fast:
         ],
         **kwargs,
     )
-# set fast_small and fast_base
+
+
+def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
+    """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import fast_small
+    >>> model = fast_small(pretrained=True).eval()
+    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+
+    Returns:
+        text detection architecture
+    """
+    return _fast(
+        "fast_small",
+        pretrained,
+        textnetfast_tiny,
+        # change ignore keys
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )
+
+
+def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
+    """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
+    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
+    Implementation based on the official PyTorch implementation: `FAST <https://github.com/czczup/FAST>`_.
+
+    >>> import torch
+    >>> from doctr.models import fast_base
+    >>> model = fast_base(pretrained=True).eval()
+    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
+    >>> out = model(input_tensor)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
+
+    Returns:
+        text detection architecture
+    """
+    return _fast(
+        "fast_base",
+        pretrained,
+        textnetfast_tiny,
+        # change ignore keys
+        ignore_keys=[
+            "classifier.6.weight",
+            "classifier.6.bias",
+        ],
+        **kwargs,
+    )
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index c54be1bd07..cb14dd236c 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from typing import Any, Union
 
 import numpy as np
@@ -7,6 +8,17 @@
 __all__ = ["RepConvLayer"]
 
 
+def get_same_padding(kernel_size):
+    if isinstance(kernel_size, tuple):
+        assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size
+        p1 = get_same_padding(kernel_size[0])
+        p2 = get_same_padding(kernel_size[1])
+        return p1, p2
+    assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`"
+    assert kernel_size % 2 > 0, "kernel size should be an odd number"
+    return kernel_size // 2
+
+
 class RepConvLayer(nn.Module):
     """Reparameterized Convolutional Layer"""
 
@@ -224,3 +236,141 @@ def config(self):
     @staticmethod
     def build_from_config(config):
         return RepConvLayer(**config)
+
+
+class My2DLayer(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, use_bn=True, act_func="relu", dropout_rate=0, ops_order="weight_bn_act"
+    ):
+        super(My2DLayer, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.use_bn = use_bn
+        self.act_func = act_func
+        self.dropout_rate = dropout_rate
+        self.ops_order = ops_order
+
+        """ modules """
+        modules = {}
+        # batch norm
+        if self.use_bn:
+            if self.bn_before_weight:
+                modules["bn"] = nn.BatchNorm2d(in_channels)
+            else:
+                modules["bn"] = nn.BatchNorm2d(out_channels)
+        else:
+            modules["bn"] = None  # type: ignore[assignment]
+
+        if self.dropout_rate > 0:
+            modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)  # type: ignore[assignment]
+        else:
+            modules["dropout"] = None  # type: ignore[assignment]
+        # weight
+        modules["weight"] = self.weight_op()  # type: ignore[operator]
+
+        # add modules
+        for op in self.ops_list:
+            if modules[op] is None:
+                continue
+            elif op == "weight":
+                if modules["dropout"] is not None:
+                    self.add_module("dropout", modules["dropout"])
+                for key in modules["weight"]:  # type ignore[attr-defined]
+                    self.add_module(key, modules["weight"][key])
+            else:
+                self.add_module(op, modules[op])
+
+    @property
+    def ops_list(self):
+        return self.ops_order.split("_")
+
+    @property
+    def bn_before_weight(self):
+        for op in self.ops_list:
+            if op == "bn":
+                return True
+            elif op == "weight":
+                return False
+        raise ValueError("Invalid ops_order: %s" % self.ops_order)
+
+    """ Methods defined in MyModule """
+
+    def forward(self, x):
+        for module in self._modules.values():
+            x = module(x)
+        return x
+
+    @property
+    def module_str(self):
+        raise NotImplementedError
+
+    @property
+    def config(self):
+        return {
+            "in_channels": self.in_channels,
+            "out_channels": self.out_channels,
+            "use_bn": self.use_bn,
+            "act_func": self.act_func,
+            "dropout_rate": self.dropout_rate,
+            "ops_order": self.ops_order,
+        }
+
+    @staticmethod
+    def build_from_config(config):
+        raise NotImplementedError
+
+    def get_flops(self, x):
+        raise NotImplementedError
+
+    @staticmethod
+    def is_zero_layer():
+        return False
+
+
+class ConvLayer(My2DLayer):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        dilation=1,
+        groups=1,
+        bias=False,
+        has_shuffle=False,
+        use_bn=True,
+        act_func="relu",
+        dropout_rate=0,
+        ops_order="weight_bn_act",
+    ):
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+        self.groups = groups
+        self.bias = bias
+        self.has_shuffle = has_shuffle
+
+        super(ConvLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
+
+    def weight_op(self):
+        padding = get_same_padding(self.kernel_size)
+        if isinstance(padding, int):
+            padding *= self.dilation
+        else:
+            padding[0] *= self.dilation
+            padding[1] *= self.dilation
+
+        weight_dict = OrderedDict()
+        weight_dict["conv"] = nn.Conv2d(
+            self.in_channels,
+            self.out_channels,
+            kernel_size=self.kernel_size,
+            stride=self.stride,
+            padding=padding,
+            dilation=self.dilation,
+            groups=self.groups,
+            bias=self.bias,
+        )
+
+        return weight_dict
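
ConvLayer.weight_op derives its padding from get_same_padding, so the
convolution keeps the spatial resolution: for an odd kernel k the padding is
k // 2, scaled by the dilation. A quick check of that rule (a standalone
sketch using torch directly):

    import torch
    import torch.nn as nn

    k, d = 3, 2
    pad = (k // 2) * d  # get_same_padding(k), scaled by the dilation
    conv = nn.Conv2d(8, 8, kernel_size=k, dilation=d, padding=pad)
    x = torch.rand(1, 8, 32, 32)
    print(conv(x).shape)  # torch.Size([1, 8, 32, 32]) -- spatial dims preserved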

From e3aa43fc77e8e33852a0f9872190315db1f2eb90 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Mon, 11 Sep 2023 10:17:05 +0200
Subject: [PATCH 39/44] [skip ci] fix ConvLayer issues

---
 doctr/models/modules/layers/pytorch.py | 78 +++++++++++---------------
 1 file changed, 33 insertions(+), 45 deletions(-)

diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index cb14dd236c..3504ad9ca5 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -19,6 +19,21 @@ def get_same_padding(kernel_size):
     return kernel_size // 2
 
 
+def build_activation(act_func, inplace=True):
+    if act_func == "relu":
+        return nn.ReLU(inplace=inplace)
+    elif act_func == "relu6":
+        return nn.ReLU6(inplace=inplace)
+    elif act_func == "tanh":
+        return nn.Tanh()
+    elif act_func == "sigmoid":
+        return nn.Sigmoid()
+    elif act_func is None:
+        return None
+    else:
+        raise ValueError("do not support: %s" % act_func)
+
+
 class RepConvLayer(nn.Module):
     """Reparameterized Convolutional Layer"""
 
@@ -171,15 +186,15 @@ def switch_to_deploy(self):
         self.fused_conv = nn.Conv2d(
             in_channels=self.main_conv.in_channels,
             out_channels=self.main_conv.out_channels,
-            kernel_size=self.main_conv.kernel_size,  # type: ignore
-            stride=self.main_conv.stride,  # type: ignore
-            padding=self.main_conv.padding,  # type: ignore
-            dilation=self.main_conv.dilation,  # type: ignore
+            kernel_size=self.main_conv.kernel_size,
+            stride=self.main_conv.stride,
+            padding=self.main_conv.padding,
+            dilation=self.main_conv.dilation,
             groups=self.main_conv.groups,
             bias=True,
         )
         self.fused_conv.weight.data = kernel
-        self.fused_conv.bias.data = bias  # type: ignore
+        self.fused_conv.bias.data = bias
         self.deploy = True
         for para in self.parameters():
             para.detach_()
@@ -193,17 +208,16 @@ def switch_to_deploy(self):
     def switch_to_test(self):
         kernel, bias = self.get_equivalent_kernel_bias()
         self.fused_conv = nn.Conv2d(
             in_channels=self.main_conv.in_channels,
             out_channels=self.main_conv.out_channels,
-            kernel_size=self.main_conv.kernel_size,  # type: ignore
-            stride=self.main_conv.stride,  # type: ignore
-            padding=self.main_conv.padding,  # type: ignore
-            dilation=self.main_conv.dilation,  # type: ignore
+            kernel_size=self.main_conv.kernel_size,
+            stride=self.main_conv.stride,
+            padding=self.main_conv.padding,
+            dilation=self.main_conv.dilation,
             groups=self.main_conv.groups,
             bias=True,
         )
         self.fused_conv.weight.data = kernel  # type ignore[operator]
-        self.fused_conv.bias.data = bias  # type: ignore
+        self.fused_conv.bias.data = bias
         for para in self.fused_conv.parameters():
             para.detach_()
         self.deploy = True
@@ -213,10 +227,6 @@ def switch_to_train(self):
             self.__delattr__("fused_conv")
         self.deploy = False
 
-    @staticmethod
-    def is_zero_layer():
-        return False
-
     @property
     def module_str(self):
         return "Rep_%dx%d" % (self.kernel_size[0], self.kernel_size[1])
@@ -260,14 +270,16 @@ def __init__(
             else:
                 modules["bn"] = nn.BatchNorm2d(out_channels)
         else:
-            modules["bn"] = None  # type: ignore[assignment]
-
+            modules["bn"] = None
+        # activation
+        modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act")
+        # dropout
         if self.dropout_rate > 0:
-            modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)  # type: ignore[assignment]
+            modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True)
         else:
-            modules["dropout"] = None  # type: ignore[assignment]
+            modules["dropout"] = None
         # weight
-        modules["weight"] = self.weight_op()  # type: ignore[operator]
+        modules["weight"] = self.weight_op()
 
         # add modules
         for op in self.ops_list:
@@ -276,7 +288,7 @@ def __init__(
             elif op == "weight":
                 if modules["dropout"] is not None:
                     self.add_module("dropout", modules["dropout"])
-                for key in modules["weight"]:  # type ignore[attr-defined]
+                for key in modules["weight"]:
                     self.add_module(key, modules["weight"][key])
             else:
                 self.add_module(op, modules[op])
@@ -294,35 +306,11 @@ def bn_before_weight(self):
                 return False
         raise ValueError("Invalid ops_order: %s" % self.ops_order)
 
-    """ Methods defined in MyModule """
-
     def forward(self, x):
         for module in self._modules.values():
             x = module(x)
         return x
 
-    @property
-    def module_str(self):
-        raise NotImplementedError
-
-    @property
-    def config(self):
-        return {
-            "in_channels": self.in_channels,
-            "out_channels": self.out_channels,
-            "use_bn": self.use_bn,
-            "act_func": self.act_func,
-            "dropout_rate": self.dropout_rate,
-            "ops_order": self.ops_order,
-        }
-
-    @staticmethod
-    def build_from_config(config):
-        raise NotImplementedError
-
-    def get_flops(self, x):
-        raise NotImplementedError
-
     @staticmethod
     def is_zero_layer():
         return False
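
The switch_to_deploy/switch_to_test logic above folds the training-time
branches into a single fused nn.Conv2d via get_equivalent_kernel_bias. That
routine is not shown in this series, but it rests on the standard conv+BN
fusion below (a generic sketch, not doctr's exact implementation, which also
merges the parallel vertical/horizontal/identity branches):

    import torch
    import torch.nn as nn

    def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
        # BN(conv(x)) == conv'(x) with a rescaled kernel and a shifted bias
        std = (bn.running_var + bn.eps).sqrt()
        scale = bn.weight / std
        kernel = conv.weight * scale.reshape(-1, 1, 1, 1)
        bias = bn.bias - bn.running_mean * scale
        if conv.bias is not None:
            bias = bias + conv.bias * scale
        return kernel, bias

    conv = nn.Conv2d(4, 4, 3, padding=1, bias=False).eval()
    bn = nn.BatchNorm2d(4).eval()  # eval mode: running statistics are used
    kernel, bias = fuse_conv_bn(conv, bn)
    fused = nn.Conv2d(4, 4, 3, padding=1, bias=True)
    fused.weight.data, fused.bias.data = kernel, bias
    x = torch.rand(1, 4, 8, 8)
    print(torch.allclose(fused(x), bn(conv(x)), atol=1e-6))  # True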

From cd35227dc6318b5f71c3c19810ee7e93fe9a6264 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 13 Sep 2023 08:26:02 +0200
Subject: [PATCH 40/44] [skip ci] fix style and quality issues in the FAST
 torch model

---
 doctr/models/detection/fast/pytorch.py | 90 +++++++++++++++++++-------
 1 file changed, 66 insertions(+), 24 deletions(-)

diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index ff2bf50b13..ae36b6ca7d 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -3,7 +3,6 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
-import time
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -40,59 +39,100 @@
 
 
 class FAST(nn.Module):
-    def __init__(self, backbone, neck, detection_head):
-        super(FAST, self).__init__()
-        self.backbone = backbone  # okay
-        self.neck = neck  # okay
-        self.det_head = detection_head
-
-    def _upsample(self, x, size, scale=1):
-        _, _, H, W = size
-        return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear")
+    def __init__(self,
+                 feat_extractor,
+                 bin_thresh: float = 0.1,
+                 head_chans: int = 32,
+                 assume_straight_pages: bool = True,
+                 exportable: bool = False,
+                 cfg: Optional[Dict[str, Any]] = None,
+                 class_names: List[str] = [CLASS_NAME],
+                 ) -> None:
+        super().__init__()
+        self.class_names = class_names
+        num_classes: int = len(self.class_names)
+        self.cfg = cfg
+        self.exportable = exportable
+        self.assume_straight_pages = assume_straight_pages
+
+        self.feat_extractor = feat_extractor
+
+        self.feat_extractor.train()
+
+        self.fpn = FASTNeck()
+        
+        self.classifier = FASTHead()
+
+        self.postprocessor = FastPostProcessor(
+            assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
+        )
+        
+        # Adjust the initializations for the model
+        for n, m in self.named_modules():
+            # Don't override the initialization of the backbone
+            if n.startswith("feat_extractor."):
+                continue
+            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+                nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1.0)
+                m.bias.data.zero_()
 
     def forward(
-        self, imgs, gt_texts=None, gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None, cfg=None
-    ):
+        self
+        x: torch.Tensor,
+        
+        # Replace the parameters below with the ones further down
+        gt_texts=None, gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None, cfg=None
+        
+        target: Optional[List[np.ndarray]] = None,
+        return_model_output: bool = False,
+        return_preds: bool = False,
+    ) -> Dict[str, torch.Tensor]:
         outputs = dict()
 
         if not self.training:
             torch.cuda.synchronize()
-            start = time.time()
 
         # backbone
-        f = self.backbone(imgs)
+        f = self.backbone(x)
 
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(backbone_time=time.time() - start))
-            start = time.time()
 
         # reduce channel
         f = self.neck(f)
 
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(neck_time=time.time() - start))
-            start = time.time()
 
         # detection
-        det_out = self.det_head(f)
+        det_out = self.classifier(f)
 
         if not self.training:
             torch.cuda.synchronize()
-            outputs.update(dict(det_head_time=time.time() - start))
 
         if self.training:
-            det_out = self._upsample(det_out, imgs.size(), scale=1)
+            det_out = self._upsample(det_out, x.size(), scale=1)
+            
+            # Change self.det_head.loss to self.loss
             det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances)
             outputs.update(det_loss)
         else:
-            det_out = self._upsample(det_out, imgs.size(), scale=4)
+            det_out = self._upsample(det_out, x.size(), scale=4)
+            
+            # Change self.det_head.get_results to self.get_results or self.postprocessing
             det_res = self.det_head.get_results(det_out, img_metas, cfg, scale=2)
             outputs.update(det_res)
 
         return outputs
 
+    def _upsample(self, x, size, scale=1):
+        _, _, H, W = size
+        return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear")
+        
 
 class FASTHead(nn.Module):
     def __init__(self):
@@ -181,9 +221,11 @@ def _fast(
 ) -> FAST:
     pretrained_backbone = pretrained_backbone and not pretrained
 
-    feat_extractor = backbone_fn(pretrained_backbone)
+    backbone = backbone_fn(pretrained_backbone)
     neck = FASTNeck()
     head = FASTHead()
+    
+    feat_extractor = backbone
 
     if not kwargs.get("class_names", None):
         kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
@@ -191,7 +233,7 @@ def _fast(
         kwargs["class_names"] = sorted(kwargs["class_names"])
 
     # Build the model
-    model = FAST(feat_extractor, neck, head, cfg=default_cfgs[arch], **kwargs)
+    model = FAST(feat_extractor=feat_extractor, cfg=default_cfgs[arch], **kwargs)
     # Load pretrained parameters
     if pretrained:
         # The number of class_names is not the same as the number of classes in the pretrained model =>
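
The _upsample helper resizes a tensor to the reference size divided by
`scale`, so scale=1 restores the full input resolution (training branch)
while scale=4 keeps the quarter-resolution map. A minimal illustration of
that size arithmetic:

    import torch
    import torch.nn.functional as F

    det_out = torch.rand(1, 5, 160, 160)  # head output at reduced resolution
    _, _, H, W = (1, 3, 640, 640)         # imgs.size()
    full = F.interpolate(det_out, size=(H // 1, W // 1), mode="bilinear")
    quarter = F.interpolate(det_out, size=(H // 4, W // 4), mode="bilinear")
    print(full.shape, quarter.shape)
    # torch.Size([1, 5, 640, 640]) torch.Size([1, 5, 160, 160])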

From 6a5b5a6dad64a8c29f8e8bb32f3014fa75a740fe Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 13 Sep 2023 09:08:26 +0200
Subject: [PATCH 41/44] [skip ci] Fast torch model implementation update
 (init working; forward, compute_loss and postprocessor not working yet)

---
 doctr/models/detection/fast/base.py    | 112 +--------------------
 doctr/models/detection/fast/pytorch.py | 130 +++++++++++++++----------
 2 files changed, 82 insertions(+), 160 deletions(-)

diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py
index 02c3936774..0828bbfc11 100644
--- a/doctr/models/detection/fast/base.py
+++ b/doctr/models/detection/fast/base.py
@@ -16,10 +16,10 @@
 
 from ..core import DetectionPostProcessor
 
-__all__ = ["_LinkNet", "LinkNetPostProcessor"]
+__all__ = ["FastPostProcessor"]
 
 
-class LinkNetPostProcessor(DetectionPostProcessor):
+class FastPostProcessor(DetectionPostProcessor):
     """Implements a post processor for LinkNet model.
 
     Args:
@@ -139,111 +139,3 @@ def bitmap_to_boxes(
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype)
         else:
             return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype)
-
-
-class _LinkNet(BaseModel):
-    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
-    <https://arxiv.org/pdf/1707.03718.pdf>`_.
-
-    Args:
-        out_chan: number of channels for the output
-    """
-
-    min_size_box: int = 3
-    assume_straight_pages: bool = True
-    shrink_ratio = 0.5
-
-    def build_target(
-        self,
-        target: List[Dict[str, np.ndarray]],
-        output_shape: Tuple[int, int, int],
-        channels_last: bool = True,
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """Build the target, and it's mask to be used from loss computation.
-
-        Args:
-            target: target coming from dataset
-            output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
-
-        Returns:
-            the new formatted target and the mask
-        """
-
-        if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
-            raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
-        if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
-            raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
-
-        h: int
-        w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
-        target_shape = (len(target), num_classes, h, w)
-
-        seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
-        seg_mask: np.ndarray = np.ones(target_shape, dtype=bool)
-
-        for idx, tgt in enumerate(target):
-            for class_idx, _tgt in enumerate(tgt.values()):
-                # Draw each polygon on gt
-                if _tgt.shape[0] == 0:
-                    # Empty image, full masked
-                    seg_mask[idx, class_idx] = False
-
-                # Absolute bounding boxes
-                abs_boxes = _tgt.copy()
-
-                if abs_boxes.ndim == 3:
-                    abs_boxes[:, :, 0] *= w
-                    abs_boxes[:, :, 1] *= h
-                    polys = abs_boxes
-                    boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1)
-                    abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32)
-                else:
-                    abs_boxes[:, [0, 2]] *= w
-                    abs_boxes[:, [1, 3]] *= h
-                    abs_boxes = abs_boxes.round().astype(np.int32)
-                    polys = np.stack(
-                        [
-                            abs_boxes[:, [0, 1]],
-                            abs_boxes[:, [0, 3]],
-                            abs_boxes[:, [2, 3]],
-                            abs_boxes[:, [2, 1]],
-                        ],
-                        axis=1,
-                    )
-                    boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1])
-
-                for poly, box, box_size in zip(polys, abs_boxes, boxes_size):
-                    # Mask boxes that are too small
-                    if box_size < self.min_size_box:
-                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
-                        continue
-
-                    # Negative shrink for gt, as described in paper
-                    polygon = Polygon(poly)
-                    distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length
-                    subject = [tuple(coor) for coor in poly]
-                    padding = pyclipper.PyclipperOffset()
-                    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
-                    shrunken = padding.Execute(-distance)
-
-                    # Draw polygon on gt if it is valid
-                    if len(shrunken) == 0:
-                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
-                        continue
-                    shrunken = np.array(shrunken[0]).reshape(-1, 2)
-                    if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
-                        seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
-                        continue
-                    cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-
-        return seg_target, seg_mask
diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index ae36b6ca7d..ed7e4b95d1 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -5,6 +5,8 @@
 
 from typing import Any, Callable, Dict, List, Optional
 
+import numpy as np
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -13,13 +15,13 @@
 from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny
 from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer
 
+from .base import FastPostProcessor
+
 from ...utils import load_pretrained_params
 
 __all__ = ["fast_tiny", "fast_small", "fast_base"]
 
 
-# modify the ignore_keys in fast_tiny, fast_small, fast_base
-
 default_cfgs: Dict[str, Dict[str, Any]] = {
     "fast_tiny": {
         "input_shape": (3, 1024, 1024),
@@ -34,11 +36,11 @@
         "url": None,
     },
 }
-# reimplement FAST Class
-# NECK AND TEXTNET AND HEAD READY
 
+# implement FastPostProcessor class
 
 class FAST(nn.Module):
+
     def __init__(self,
                  feat_extractor,
                  bin_thresh: float = 0.1,
@@ -54,20 +56,12 @@ def __init__(self,
         self.cfg = cfg
         self.exportable = exportable
         self.assume_straight_pages = assume_straight_pages
-
         self.feat_extractor = feat_extractor
-
         self.feat_extractor.train()
-
         self.fpn = FASTNeck()
-        
         self.classifier = FASTHead()
-
-        self.postprocessor = FastPostProcessor(
-            assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh
-        )
+        self.postprocessor = FastPostProcessor(assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh)
         
-        # Adjust the initializations for the model
         for n, m in self.named_modules():
             # Don't override the initialization of the backbone
             if n.startswith("feat_extractor."):
@@ -81,58 +75,91 @@ def __init__(self,
                 m.bias.data.zero_()
 
     def forward(
-        self
+        self,
         x: torch.Tensor,
-        
-        # Replace the parameters below with the ones further down
-        gt_texts=None, gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None, cfg=None
-        
         target: Optional[List[np.ndarray]] = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:
-        outputs = dict()
-
-        if not self.training:
-            torch.cuda.synchronize()
-
-        # backbone
-        f = self.backbone(x)
-
+    ) -> Dict[str, torch.Tensor]:  
+                      
+        out: Dict[str, Any] = {}
         if not self.training:
             torch.cuda.synchronize()
-
-        # reduce channel
-        f = self.neck(f)
-
+        feats = self.feat_extractor(x)       
         if not self.training:
             torch.cuda.synchronize()
-
-        # detection
-        det_out = self.classifier(f)
-
+        logits = self.fpn(feats)
         if not self.training:
             torch.cuda.synchronize()
+        logits = self.classifier(logits)
+        logits = self._upsample(logits, x.size(), scale=1)
+        if self.exportable:
+            out["logits"] = logits
+            return out
+
+        # Do we really need the sigmoid here?
+        if return_model_output or target is None or return_preds:
+            prob_map = torch.sigmoid(logits)
+
+        if return_model_output:
+            out["out_map"] = prob_map
+        if target is None or return_preds:
+            # Post-process boxes
+            out["preds"] = [
+                dict(zip(self.class_names, preds))
+                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+            ]
+        if target is not None:
+            loss = self.compute_loss(logits, target)
+            out["loss"] = loss
+        return out
+
+    def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
+        # self.compute_loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances)
+
+        # output
+        kernels = out[:, 0, :, :]  # 4*640*640
+        texts = self._max_pooling(kernels, scale=1)  # 4*640*640
+        embs = out[:, 1:, :, :]  # 4*4*640*640
+
+        # text loss
+        selected_masks = ohem_batch(texts, gt_texts, training_masks)
+        loss_text = self.text_loss(texts, gt_texts, selected_masks, reduce=False)
+        iou_text = iou((texts > 0).long(), gt_texts, training_masks, reduce=False)
+        losses = dict(
+            loss_text=loss_text,
+            iou_text=iou_text
+        )
+    
+        # kernel loss
+        selected_masks = gt_texts * training_masks
+        loss_kernel = self.kernel_loss(kernels, gt_kernels, selected_masks, reduce=False)
+        loss_kernel = torch.mean(loss_kernel, dim=0)
+        iou_kernel = iou((kernels > 0).long(), gt_kernels, selected_masks, reduce=False)
+        losses.update(dict(
+            loss_kernels=loss_kernel,
+            iou_kernel=iou_kernel
+        ))
+    
+        # auxiliary loss
+        loss_emb = self.emb_loss(embs, gt_instances, gt_kernels, training_masks, reduce=False)
+        losses.update(dict(
+            loss_emb=loss_emb
+        ))
+    
+        return losses
 
-        if self.training:
-            det_out = self._upsample(det_out, x.size(), scale=1)
-            
-            # Change self.det_head.loss to self.loss
-            det_loss = self.det_head.loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances)
-            outputs.update(det_loss)
-        else:
-            det_out = self._upsample(det_out, x.size(), scale=4)
-            
-            # Change self.det_head.get_results to self.get_results or self.postprocessing
-            det_res = self.det_head.get_results(det_out, img_metas, cfg, scale=2)
-            outputs.update(det_res)
-
-        return outputs
+    def _max_pooling(self, x, scale=1):
+        if scale == 1:
+            x = self.pooling_1s(x)
+        elif scale == 2:
+            x = self.pooling_2s(x)
+        return x
 
     def _upsample(self, x, size, scale=1):
         _, _, H, W = size
         return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear")
-        
+
 
 class FASTHead(nn.Module):
     def __init__(self):
@@ -246,6 +273,9 @@ def _fast(
     return model
 
 
+# modify the ignore_keys in fast_tiny, fast_small, fast_base
+
+
 def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
     """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
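
The rewritten forward now follows the output contract of doctr's other
detection models: a dict that may hold "logits" (export mode), "out_map"
(sigmoid probability map), "preds" (post-processed boxes per class) and
"loss" (when a target is given). A runnable toy model showing just that
contract (a sketch; the FAST-specific losses and postprocessor are still
WIP in this patch):

    from typing import Any, Dict, List, Optional

    import numpy as np
    import torch
    import torch.nn as nn

    class TinyDetector(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.net = nn.Conv2d(3, 1, 3, padding=1)

        def forward(
            self,
            x: torch.Tensor,
            target: Optional[List[np.ndarray]] = None,
            return_model_output: bool = False,
        ) -> Dict[str, Any]:
            out: Dict[str, Any] = {}
            logits = self.net(x)
            if return_model_output:
                out["out_map"] = torch.sigmoid(logits)
            if target is not None:
                # stand-in loss; FAST combines text, kernel and embedding terms
                out["loss"] = torch.sigmoid(logits).mean()
            return out

    out = TinyDetector()(torch.rand(1, 3, 64, 64), return_model_output=True)
    print(sorted(out))  # ['out_map']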

From 1bc5ba4ef4fb17d94724240570b42d1aca1b7040 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 13 Sep 2023 11:04:47 +0200
Subject: [PATCH 42/44] [skip ci] update losses for Fast Torch model

---
 doctr/models/detection/fast/base.py    |   4 +-
 doctr/models/detection/fast/pytorch.py | 200 ++++++++++++++++++-------
 doctr/models/modules/layers/pytorch.py |   2 +-
 3 files changed, 146 insertions(+), 60 deletions(-)

diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py
index 0828bbfc11..727e10e006 100644
--- a/doctr/models/detection/fast/base.py
+++ b/doctr/models/detection/fast/base.py
@@ -5,15 +5,13 @@
 
 # Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
 
-from typing import Dict, List, Tuple, Union
+from typing import List, Union
 
 import cv2
 import numpy as np
 import pyclipper
 from shapely.geometry import Polygon
 
-from doctr.models.core import BaseModel
-
 from ..core import DetectionPostProcessor
 
 __all__ = ["FastPostProcessor"]
diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index ed7e4b95d1..a5f008f2ab 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -6,7 +6,6 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
@@ -14,10 +13,10 @@
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny
 from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer
-
-from .base import FastPostProcessor
+from doctr.utils.metrics import box_iou
 
 from ...utils import load_pretrained_params
+from .base import FastPostProcessor
 
 __all__ = ["fast_tiny", "fast_small", "fast_base"]
 
@@ -39,20 +38,21 @@
 
 # implement FastPostProcessor class
 
-class FAST(nn.Module):
 
-    def __init__(self,
-                 feat_extractor,
-                 bin_thresh: float = 0.1,
-                 head_chans: int = 32,
-                 assume_straight_pages: bool = True,
-                 exportable: bool = False,
-                 cfg: Optional[Dict[str, Any]] = None,
-                 class_names: List[str] = [CLASS_NAME],
-                 ) -> None:
+class FAST(nn.Module):
+    def __init__(
+        self,
+        feat_extractor,
+        bin_thresh: float = 0.1,
+        head_chans: int = 32,
+        assume_straight_pages: bool = True,
+        exportable: bool = False,
+        cfg: Optional[Dict[str, Any]] = None,
+        class_names: List[str] = [CLASS_NAME],
+    ) -> None:
         super().__init__()
         self.class_names = class_names
-        num_classes: int = len(self.class_names)
+        self.num_classes = len(self.class_names)
         self.cfg = cfg
         self.exportable = exportable
         self.assume_straight_pages = assume_straight_pages
@@ -61,7 +61,7 @@ def __init__(self,
         self.fpn = FASTNeck()
         self.classifier = FASTHead()
         self.postprocessor = FastPostProcessor(assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh)
-        
+
         for n, m in self.named_modules():
             # Don't override the initialization of the backbone
             if n.startswith("feat_extractor."):
@@ -80,12 +80,14 @@ def forward(
         target: Optional[List[np.ndarray]] = None,
         return_model_output: bool = False,
         return_preds: bool = False,
-    ) -> Dict[str, torch.Tensor]:  
-                      
+    ) -> Dict[str, torch.Tensor]:
+        # Replace the parameters below with these parameters:
+        # gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None,
+
         out: Dict[str, Any] = {}
         if not self.training:
             torch.cuda.synchronize()
-        feats = self.feat_extractor(x)       
+        feats = self.feat_extractor(x)
         if not self.training:
             torch.cuda.synchronize()
         logits = self.fpn(feats)
@@ -97,12 +99,12 @@ def forward(
             out["logits"] = logits
             return out
 
-        # Do we really need the sigmoid here?
         if return_model_output or target is None or return_preds:
             prob_map = torch.sigmoid(logits)
 
         if return_model_output:
             out["out_map"] = prob_map
+
         if target is None or return_preds:
             # Post-process boxes
             out["preds"] = [
@@ -114,39 +116,29 @@ def forward(
             out["loss"] = loss
         return out
 
-    def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
-        # self.compute_loss(det_out, gt_texts, gt_kernels, training_masks, gt_instances)
+    def compute_loss(self, out_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
+        # These parameters are missing: (gt_kernels, training_masks, gt_instances)
 
         # output
-        kernels = out[:, 0, :, :]  # 4*640*640
+        kernels = out_map[:, 0, :, :]  # 4*640*640
         texts = self._max_pooling(kernels, scale=1)  # 4*640*640
-        embs = out[:, 1:, :, :]  # 4*4*640*640
+        embs = out_map[:, 1:, :, :]  # 4*4*640*640
 
         # text loss
-        selected_masks = ohem_batch(texts, gt_texts, training_masks)
-        loss_text = self.text_loss(texts, gt_texts, selected_masks, reduce=False)
-        iou_text = iou((texts > 0).long(), gt_texts, training_masks, reduce=False)
-        losses = dict(
-            loss_text=loss_text,
-            iou_text=iou_text
-        )
-    
+        loss_text = multiclass_dice_loss(texts, target, self.num_classes, loss_weight=0.25)
+        iou_text = box_iou((texts > 0).long(), target)
+        losses = dict(loss_text=loss_text, iou_text=iou_text)
+
         # kernel loss
-        selected_masks = gt_texts * training_masks
-        loss_kernel = self.kernel_loss(kernels, gt_kernels, selected_masks, reduce=False)
+        loss_kernel = multiclass_dice_loss(kernels, None, self.num_classes, loss_weight=1.0)
         loss_kernel = torch.mean(loss_kernel, dim=0)
-        iou_kernel = iou((kernels > 0).long(), gt_kernels, selected_masks, reduce=False)
-        losses.update(dict(
-            loss_kernels=loss_kernel,
-            iou_kernel=iou_kernel
-        ))
-    
+        iou_kernel = box_iou((kernels > 0).long(), None)
+        losses.update(dict(loss_kernels=loss_kernel, iou_kernel=iou_kernel))
+
         # auxiliary loss
-        loss_emb = self.emb_loss(embs, gt_instances, gt_kernels, training_masks, reduce=False)
-        losses.update(dict(
-            loss_emb=loss_emb
-        ))
-    
+        loss_emb = emb_loss_v2(embs, None, None, None)
+        losses.update(dict(loss_emb=loss_emb))
+
         return losses
 
     def _max_pooling(self, x, scale=1):
@@ -249,9 +241,9 @@ def _fast(
     pretrained_backbone = pretrained_backbone and not pretrained
 
     backbone = backbone_fn(pretrained_backbone)
-    neck = FASTNeck()
-    head = FASTHead()
-    
+
     feat_extractor = backbone
 
     if not kwargs.get("class_names", None):
@@ -273,9 +265,6 @@ def _fast(
     return model
 
 
-# modify the ignore_keys in fast_tiny, fast_small, fast_base
-
-
 def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
     """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
     Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
@@ -299,8 +288,8 @@ def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
         textnetfast_tiny,
         # change ignore keys
         ignore_keys=[
-            "classifier.6.weight",
-            "classifier.6.bias",
+            "classifier.final.conv.weight",
+            "classifier.final.conv.bias",
         ],
         **kwargs,
     )
@@ -329,8 +318,8 @@ def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
         textnetfast_tiny,
         # change ignore keys
         ignore_keys=[
-            "classifier.6.weight",
-            "classifier.6.bias",
+            "classifier.final.conv.weight",
+            "classifier.final.conv.bias",
         ],
         **kwargs,
     )
@@ -359,8 +348,107 @@ def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
         textnetfast_tiny,
         # change ignore keys
         ignore_keys=[
-            "classifier.6.weight",
-            "classifier.6.bias",
+            "classifier.final.conv.weight",
+            "classifier.final.conv.bias",
         ],
         **kwargs,
     )
+
+
+# check that this code works; it comes from https://github.com/czczup/FAST/blob/main/models/loss/dice_loss.py
+# make sure selected_masks gets inserted into this code
+def multiclass_dice_loss(inputs, targets, num_classes, loss_weight=1.0):
+    # Convert targets to one-hot encoding
+    targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2).float()
+
+    # Calculate intersection and union
+    intersection = torch.sum(inputs * targets, dim=(2, 3))
+    union = torch.sum(inputs, dim=(2, 3)) + torch.sum(targets, dim=(2, 3))
+
+    # Calculate Dice coefficients for each class
+    dice_coeffs = (2.0 * intersection + 1e-5) / (union + 1e-5)
+
+    # Calculate the average Dice loss across all classes
+    dice_loss = 1.0 - torch.mean(dice_coeffs)
+
+    return loss_weight * dice_loss
+
+
+# simplify emb_loss_v2
+def emb_loss_v2(emb, instance, kernel, training_mask):
+    training_mask = (training_mask > 0.5).long()
+    kernel = (kernel > 0.5).long()
+    instance = instance * training_mask
+    instance_kernel = (instance * kernel).view(-1)
+    instance = instance.view(-1)
+    emb = emb.view(4, -1)
+
+    unique_labels, unique_ids = torch.unique(instance_kernel, sorted=True, return_inverse=True)
+    num_instance = unique_labels.size(0)
+    if num_instance <= 1:
+        return 0
+
+    emb_mean = emb.new_zeros((4, num_instance), dtype=torch.float32)
+    for i, lb in enumerate(unique_labels):
+        if lb == 0:
+            continue
+        ind_k = instance_kernel == lb
+        emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1)
+
+    l_agg = emb.new_zeros(num_instance, dtype=torch.float32)  # bug
+    for i, lb in enumerate(unique_labels):
+        if lb == 0:
+            continue
+        ind = instance == lb
+        emb_ = emb[:, ind]
+        dist = (emb_ - emb_mean[:, i : i + 1]).norm(p=2, dim=0)
+        dist = F.relu(dist - 0.5) ** 2
+        l_agg[i] = torch.mean(torch.log(dist + 1.0))
+    l_agg = torch.mean(l_agg[1:])
+
+    if num_instance > 2:
+        emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1)
+        emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view(-1, 4)
+
+        mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view(-1, 1).repeat(1, 4)
+        mask = mask.view(num_instance, num_instance, -1)
+        mask[0, :, :] = 0
+        mask[:, 0, :] = 0
+        mask = mask.view(num_instance * num_instance, -1)
+
+        dist = emb_interleave - emb_band
+        dist = dist[mask > 0].view(-1, 4).norm(p=2, dim=1)
+        dist = F.relu(2 * 1.5 - dist) ** 2
+
+        l_dis = [torch.log(dist + 1.0)]
+        emb_bg = emb[:, instance == 0].view(4, -1)
+        if emb_bg.size(1) > 100:
+            rand_ind = np.random.permutation(emb_bg.size(1))[:100]
+            emb_bg = emb_bg[:, rand_ind]
+        if emb_bg.size(1) > 0:
+            for i, lb in enumerate(unique_labels):
+                if lb == 0:
+                    continue
+                dist = (emb_bg - emb_mean[:, i : i + 1]).norm(p=2, dim=0)
+                dist = F.relu(2 * 1.5 - dist) ** 2
+                l_dis_bg = torch.mean(torch.log(dist + 1.0), 0, keepdim=True)
+                l_dis.append(l_dis_bg)
+        l_dis = torch.mean(torch.cat(l_dis))
+    else:
+        l_dis = 0
+    l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001
+    loss = l_agg + l_dis + l_reg
+    return loss
+
+
+def emb_loss_v2_batch(emb, instance, kernel, training_mask, reduce=True):
+    # batch wrapper mirroring EmbLoss_v2.forward in the upstream FAST repo:
+    # per-sample emb_loss_v2 values, weighted by 0.25, optionally averaged
+    loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32)
+
+    for i in range(loss_batch.size(0)):
+        loss_batch[i] = emb_loss_v2(emb[i], instance[i], kernel[i], training_mask[i])
+
+    loss_batch = 0.25 * loss_batch
+
+    if reduce:
+        loss_batch = torch.mean(loss_batch)
+
+    return loss_batch
diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py
index 3504ad9ca5..ae1f4ebed3 100644
--- a/doctr/models/modules/layers/pytorch.py
+++ b/doctr/models/modules/layers/pytorch.py
@@ -216,7 +216,7 @@ def switch_to_test(self):
             groups=self.main_conv.groups,
             bias=True,
         )
-        self.fused_conv.weight.data = kernel  # type ignore[operator]
+        self.fused_conv.weight.data = kernel
         self.fused_conv.bias.data = bias
         for para in self.fused_conv.parameters():
             para.detach_()
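
A tiny worked example for the multiclass_dice_loss added in this patch: when
the predicted maps equal the one-hot targets, the per-class intersection is
the class area and the union is twice that area, so every Dice coefficient
is ~1 and the loss is ~0 (the arithmetic below mirrors the function body):

    import torch
    import torch.nn.functional as F

    targets = torch.tensor([[[0, 1], [1, 0]]])  # (B, H, W) class indices
    onehot = F.one_hot(targets, num_classes=2).permute(0, 3, 1, 2).float()
    inputs = onehot.clone()                     # a "perfect" prediction

    intersection = torch.sum(inputs * onehot, dim=(2, 3))
    union = torch.sum(inputs, dim=(2, 3)) + torch.sum(onehot, dim=(2, 3))
    dice_coeffs = (2.0 * intersection + 1e-5) / (union + 1e-5)
    print(1.0 - torch.mean(dice_coeffs))        # tensor(0.) -> zero loss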

From 018a3c9d7f0fae47337e5de5d617453e98b67170 Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Wed, 13 Sep 2023 16:55:46 +0200
Subject: [PATCH 43/44] [skip ci] forward and compute_loss seem to be ok; the
 FAST torch model still needs its postprocessor implementation

---
 doctr/models/detection/fast/pytorch.py | 110 +++++++++++++++++++++----
 1 file changed, 94 insertions(+), 16 deletions(-)

diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index a5f008f2ab..8ba314d962 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -10,6 +10,8 @@
 import torch.nn as nn
 from torch.nn import functional as F
 
+from PIL import Image
+
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny
 from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer
@@ -36,7 +38,7 @@
     },
 }
 
-# implement FastPostProcessor class
+# implement FastPostProcessor class with the head's get_results logic
 
 
 class FAST(nn.Module):
@@ -81,27 +83,22 @@ def forward(
         return_model_output: bool = False,
         return_preds: bool = False,
     ) -> Dict[str, torch.Tensor]:
-        # Replace the parameters below with these parameters:
-        # gt_kernels=None, training_masks=None, gt_instances=None, img_metas=None,
 
-        out: Dict[str, Any] = {}
-        if not self.training:
-            torch.cuda.synchronize()
-        feats = self.feat_extractor(x)
-        if not self.training:
-            torch.cuda.synchronize()
+        x, gt_texts, gt_kernels, training_masks, gt_instances, img_metas = self.prepare_data(x, target)
+
+        feats = self.backbone(x)  
         logits = self.fpn(feats)
-        if not self.training:
-            torch.cuda.synchronize()
         logits = self.classifier(logits)
         logits = self._upsample(logits, x.size(), scale=1)
+
+        out: Dict[str, Any] = {}
         if self.exportable:
             out["logits"] = logits
             return out
 
         if return_model_output or target is None or return_preds:
             prob_map = torch.sigmoid(logits)
-
+
         if return_model_output:
             out["out_map"] = prob_map
 
@@ -109,11 +106,13 @@ def forward(
             # Post-process boxes
             out["preds"] = [
                 dict(zip(self.class_names, preds))
-                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
+                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy(), img_metas, cfg, scale=2)
             ]
+
         if target is not None:
-            loss = self.compute_loss(logits, target)
+            loss = self.compute_loss(logits, gt_texts, gt_kernels, training_masks, gt_instances)
             out["loss"] = loss
+
         return out
 
     def compute_loss(self, out_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
@@ -152,7 +151,85 @@ def _upsample(self, x, size, scale=1):
         _, _, H, W = size
         return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear")
 
-
+    def prepare_data(self,
+        x: torch.Tensor,
+        target: Optional[List[np.ndarray]] = None):
+
+        target = target[:self.num_classes]
+        gt_instance = np.zeros(x.shape[0:2], dtype='uint8')
+        training_mask = np.ones(x.shape[0:2], dtype='uint8')
+
+        if target.shape[0] > 0:
+            target = np.reshape(target * ([x.shape[1], x.shape[0]] * 4),
+                                (target.shape[0], -1, 2)).astype('int32')
+            for i in range(target.shape[0]):
+                cv2.drawContours(gt_instance, [target[i]], -1, i + 1, -1)
+
+        gt_kernels = np.array([np.zeros(x.shape[0:2], dtype='uint8')] * len(target)) # [instance_num, h, w]
+        gt_kernel = self.min_pooling(gt_kernels)
+
+        shrink_kernel_scale = 0.1
+        gt_kernel_shrinked = np.zeros(x.shape[0:2], dtype='uint8')
+        kernel_target = shrink(target, shrink_kernel_scale)
+        
+        for i in range(target.shape[0]):
+            cv2.drawContours(gt_kernel_shrinked, [kernel_target[i]], -1, 1, -1)
+        gt_kernel = np.maximum(gt_kernel, gt_kernel_shrinked)
+
+        gt_text = gt_instance.copy()
+        gt_text[gt_text > 0] = 1
+
+        x = Image.fromarray(x)
+        
+        img_meta = dict(
+            org_img_size=np.array(img.shape[:2])
+            img_size=np.array(img.shape[:2]),
+            filename=filename))
+
+        img = scale_aligned_short(img, self.short_size)
+        x = x.convert('RGB')
+
+        x = transforms.ToTensor()(x)
+        x = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x)
+
+        return x, \ 
+               torch.from_numpy(gt_text).long(), \
+               torch.from_numpy(gt_kernel).long(), \
+               torch.from_numpy(training_mask).long(), \
+               torch.from_numpy(gt_instance).long()\
+               img_meta
+        
+    # simplify this method
+    def min_pooling(self, input):
+        input = torch.tensor(input, dtype=torch.float)
+        temp = input.sum(dim=0).to(torch.uint8)
+        overlap = (temp > 1).to(torch.float32).unsqueeze(0).unsqueeze(0)
+        overlap = self.overlap_pool(overlap).squeeze(0).squeeze(0)
+
+        B = input.size(0)
+        h_sum = input.sum(dim=2) > 0
+
+        # first/last non-empty row: descending index weights make argmax return
+        # the first True entry, ascending weights the last one
+        h_sum_ = h_sum.long() * torch.arange(h_sum.shape[1], 0, -1)
+        h_min = torch.argmax(h_sum_, 1, keepdim=True)
+        h_sum_ = h_sum.long() * torch.arange(1, h_sum.shape[1] + 1)
+        h_max = torch.argmax(h_sum_, 1, keepdim=True)
+
+        w_sum = input.sum(dim=1) > 0
+        w_sum_ = w_sum.long() * torch.arange(w_sum.shape[1], 0, -1)
+        w_min = torch.argmax(w_sum_, 1, keepdim=True)
+        w_sum_ = w_sum.long() * torch.arange(1, w_sum.shape[1] + 1)
+        w_max = torch.argmax(w_sum_, 1, keepdim=True)
+
+        for i in range(B):
+            region = input[i:i + 1, h_min[i]:h_max[i] + 1, w_min[i]:w_max[i] + 1]
+            region = self.pad(region)
+            region = -self.pooling(-region)  # min-pooling implemented as negated max-pooling
+            input[i:i + 1, h_min[i]:h_max[i] + 1, w_min[i]:w_max[i] + 1] = region
+
+        x = input.sum(dim=0).to(torch.uint8)
+        x[overlap > 0] = 0  # overlapping regions
+        return x.numpy()
+
 class FASTHead(nn.Module):
     def __init__(self):
         super(FASTHead, self).__init__()
@@ -239,7 +316,8 @@ def _fast(
     **kwargs: Any,
 ) -> FAST:
     pretrained_backbone = pretrained_backbone and not pretrained
-
+
+    # TODO: fix the encapsulation of the backbone, neck and head
     backbone = backbone_fn(pretrained_backbone)
     FASTNeck()
     FASTHead()

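For context, `prepare_data` rasterizes the word polygons into the maps the loss
consumes: `gt_instance` holds a 1-based id per text region (cv2.drawContours with
thickness -1 fills the polygon), `gt_text` is its binarization, and a shrunken
copy of each polygon seeds `gt_kernel`. A minimal sketch of the instance/text
maps, assuming polygons arrive as an (N, 4, 2) array in absolute pixel
coordinates (the shrink step and min-pooling are left out):

    import cv2
    import numpy as np

    def build_gt_maps(polys: np.ndarray, h: int, w: int):
        gt_instance = np.zeros((h, w), dtype=np.uint8)
        for i, poly in enumerate(polys):
            # fill polygon i with its 1-based instance id
            cv2.drawContours(gt_instance, [poly.astype(np.int32)], -1, i + 1, -1)
        gt_text = (gt_instance > 0).astype(np.uint8)  # any text pixel -> 1
        return gt_instance, gt_text
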
From 5de4fff20360723e95a397442217ba0b8f820a7f Mon Sep 17 00:00:00 2001
From: nikokks <playe.nicolas@gmail.com>
Date: Thu, 14 Sep 2023 16:08:34 +0200
Subject: [PATCH 44/44] [skip ci] progress on the FAST torch model forward
 method

---
 doctr/models/detection/__init__.py        |  1 +
 doctr/models/detection/fast/pytorch.py    | 27 ++++++++++++-----------
 doctr/models/detection/zoo.py             |  3 +++
 tests/pytorch/test_models_detection_pt.py |  7 ++++++
 4 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/doctr/models/detection/__init__.py b/doctr/models/detection/__init__.py
index e2fafbadba..702ee501d8 100644
--- a/doctr/models/detection/__init__.py
+++ b/doctr/models/detection/__init__.py
@@ -1,3 +1,4 @@
 from .differentiable_binarization import *
 from .linknet import *
 from .zoo import *
+from .fast import *
diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py
index 8ba314d962..4dba5c278b 100644
--- a/doctr/models/detection/fast/pytorch.py
+++ b/doctr/models/detection/fast/pytorch.py
@@ -9,14 +9,13 @@
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-
 from PIL import Image
 
 from doctr.file_utils import CLASS_NAME
 from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny
 from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer
 from doctr.utils.metrics import box_iou
-
+import cv2
 from ...utils import load_pretrained_params
 from .base import FastPostProcessor
 
@@ -63,7 +62,9 @@ def __init__(
         self.fpn = FASTNeck()
         self.classifier = FASTHead()
         self.postprocessor = FastPostProcessor(assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh)
-
+        self.overlap_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
+        self.pooling = nn.MaxPool2d(kernel_size=9, stride=1)
+        self.pad = nn.ZeroPad2d(padding=(9 - 1) // 2)
         for n, m in self.named_modules():
             # Don't override the initialization of the backbone
             if n.startswith("feat_extractor."):
@@ -83,7 +84,7 @@ def forward(
         return_model_output: bool = False,
         return_preds: bool = False,
     ) -> Dict[str, torch.Tensor]:
-
+
         x, gt_texts, gt_kernels, training_masks, gt_instances, img_metas = self.prepare_data(x, target)
 
         feats = self.backbone(x)  
@@ -154,8 +155,9 @@ def _upsample(self, x, size, scale=1):
     def prepare_data(self,
         x: torch.Tensor,
         target: Optional[List[np.ndarray]] = None):
-
-        target = target[:self.num_classes]
+        
+        target = np.array([dico['words'] for dico in target[:self.num_classes]]).reshape(-1, 1)
+
         gt_instance = np.zeros(x.shape[0:2], dtype='uint8')
         training_mask = np.ones(x.shape[0:2], dtype='uint8')
 
@@ -182,9 +184,8 @@ def prepare_data(self,
         x = Image.fromarray(x)
         
         img_meta = dict(
-            org_img_size=np.array(img.shape[:2])
-            img_size=np.array(img.shape[:2]),
-            filename=filename))
+            org_img_size=np.array(img.shape[:2]),
+            img_size=np.array(img.shape[:2]))
 
         img = scale_aligned_short(img, self.short_size)
         x = x.convert('RGB')
@@ -192,11 +193,10 @@ def prepare_data(self,
         x = transforms.ToTensor()(x)
         x = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x)
 
-        return x, \ 
-               torch.from_numpy(gt_text).long(), \
-               torch.from_numpy(gt_kernel).long(), \
+        return x, torch.from_numpy(gt_text).long(),  \
+               torch.from_numpy(gt_kernel).long(),  \
                torch.from_numpy(training_mask).long(), \
-               torch.from_numpy(gt_instance).long()\
+               torch.from_numpy(gt_instance).long(),  \
                img_meta
         
     # simplify this method
@@ -530,3 +530,4 @@ def forward(self, emb, instance, kernel, training_mask, reduce=True):
             loss_batch = torch.mean(loss_batch)
 
         return loss_batch
+
diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py
index a07febdf27..c712df2aee 100644
--- a/doctr/models/detection/zoo.py
+++ b/doctr/models/detection/zoo.py
@@ -28,6 +28,9 @@
         "linknet_resnet18",
         "linknet_resnet34",
         "linknet_resnet50",
+        "fast_tiny",
+        "fast_small",
+        "fast_base",
     ]
     ROT_ARCHS = ["db_resnet50_rotation"]
 
diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py
index 39eae65168..d2f4bb8885 100644
--- a/tests/pytorch/test_models_detection_pt.py
+++ b/tests/pytorch/test_models_detection_pt.py
@@ -23,6 +23,9 @@
         ["linknet_resnet18", (3, 512, 512), (1, 512, 512), True],
         ["linknet_resnet34", (3, 512, 512), (1, 512, 512), True],
         ["linknet_resnet50", (3, 512, 512), (1, 512, 512), True],
+        ["fast_tiny", (3, 512, 512), (1, 512, 512), True],
+        ["fast_small", (3, 512, 512), (1, 512, 512), True],
+        ["fast_base", (3, 512, 512), (1, 512, 512), True],
     ],
 )
 def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode):
@@ -125,6 +129,9 @@ def test_dilate():
         ["linknet_resnet18", (3, 512, 512), (1, 512, 512)],
         ["linknet_resnet34", (3, 512, 512), (1, 512, 512)],
         ["linknet_resnet50", (3, 512, 512), (1, 512, 512)],
+        ["fast_tiny", (3, 512, 512), (1, 512, 512), True],
+        ["fast_small", (3, 512, 512), (1, 512, 512), True],
+        ["fast_base", (3, 512, 512), (1, 512, 512), True],
     ],
 )
 def test_models_onnx_export(arch_name, input_shape, output_size):
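
Once the postprocessor lands, the three registered archs should be reachable like
any other detection model through the zoo. A usage sketch (pretrained weights and
the end-to-end predictor behaviour are assumptions at this stage):

    import numpy as np
    from doctr.models.detection.zoo import detection_predictor

    predictor = detection_predictor("fast_tiny", pretrained=False)
    # run on one dummy page of shape (H, W, C), dtype uint8
    out = predictor([np.zeros((512, 512, 3), dtype=np.uint8)])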