Detection Model TextNet (TF+PT) #1292

Closed
wants to merge 44 commits into from
Changes from all commits (44 commits):
03c5834  adding TextNet backbone : first try (nikokks, Aug 29, 2023)
ceeca80  second try: uploading TextNet backbone (nikokks, Aug 29, 2023)
151e65f  correcting some syntax for textnet backbone (nikokks, Aug 29, 2023)
57d33f8  changing import layer for textnet (nikokks, Aug 29, 2023)
dae10b2  renaming textnet to textnetfast (nikokks, Aug 29, 2023)
2465696  adding textnetfast test and ran it good (nikokks, Aug 29, 2023)
ab5b5b3  correcting some stuff to run TextNetFast : not working for the moment (nikokks, Aug 29, 2023)
564f9eb  run test_classification_pytorch OK (nikokks, Aug 29, 2023)
7f48b06  first commit of textnetfast model in tensorflow (nikokks, Aug 29, 2023)
346bcd4  ending implementing tensorflow textnet classification model => go for… (nikokks, Aug 29, 2023)
eb9299e  some changes + make style + make quality (nikokks, Aug 30, 2023)
fd3d85b  some changes (nikokks, Aug 30, 2023)
10e7e05  some changes (nikokks, Aug 30, 2023)
d49c3f2  some changes3 (nikokks, Aug 30, 2023)
6599951  [skip ci] some changes 4 (nikokks, Aug 30, 2023)
491ddd5  [skip ci] some changes 5 (nikokks, Aug 30, 2023)
2f2e769  [skip ci] some changes 5 (nikokks, Aug 30, 2023)
f315122  [skip ci] removing ingore_keys in tf textnet model (nikokks, Aug 31, 2023)
ad1bc73  [skip ci] removing ingore_keys in tf textnet model (nikokks, Aug 31, 2023)
0bf73b3  [skip ci] first layers of the model to a single block (nikokks, Aug 31, 2023)
b131218  [skip ci] creating blocks in layers (nikokks, Aug 31, 2023)
f9a63c0  [skip ci] correction of some errors in make quality (nikokks, Aug 31, 2023)
064ca6b  [skip ci] adding fast inference fo textnet model (nikokks, Sep 1, 2023)
53c144b  [skip ci] some changes 6 (nikokks, Sep 3, 2023)
774138d  [skip ci] override eval method for speed up eval mode of TextNetFast (nikokks, Sep 4, 2023)
a7e8d73  [skip ci] override eval and train mode for textNetFast done (nikokks, Sep 4, 2023)
ae5f7c4  [skip ci] make style and make quality Done (nikokks, Sep 4, 2023)
7c0bba3  [skip ci] changing eval and train method to switch repconvlayer (nikokks, Sep 4, 2023)
6814718  [skip ci] changing eval and train method to switch repconvlayer setti… (nikokks, Sep 4, 2023)
eac952e  [skip ci] TextNetFast pytorch model done (nikokks, Sep 5, 2023)
0ef122f  TextNetFast model implemented in torch (nikokks, Sep 5, 2023)
dcd2ece  adding textNetFast tensorflow implementation (nikokks, Sep 8, 2023)
a8ac914  starting to solving eval mode of textnetFadt model (nikokks, Sep 9, 2023)
48fb8e5  [skip ci] Last modification for switch train to eval mode for textnet… (nikokks, Sep 9, 2023)
77d78b6  [skip ci] dleting tensorflow textnetfast code for futher integreation (nikokks, Sep 10, 2023)
d7b9ca2  [skip ci] first commit on adding neck,head and fast model in torch (nikokks, Sep 10, 2023)
2e1ea2f  [skip ci] backbone+neck+head of FAst torch ready, remains fast class (nikokks, Sep 10, 2023)
e83b2ac  [skip ci] correcting some stuff for Fast Torch Model on style and qua… (nikokks, Sep 11, 2023)
e3aa43f  [skip ci] correcting some stuff for convlayer (nikokks, Sep 11, 2023)
cd35227  [skip ci] correcting some stuff for Fast Torch Model on style and qua… (nikokks, Sep 13, 2023)
6a5b5a6  [skip ci] implementation of Fast torch model update (working(init), n… (nikokks, Sep 13, 2023)
1bc5ba4  [skip ci] update losses for Fast Torch model (nikokks, Sep 13, 2023)
018a3c9  [skip ci] forward and compute_loss seems to be ok, need postprocessor… (nikokks, Sep 13, 2023)
5de4fff  [skip ci] advancements in Fast torch model forward method (nikokks, Sep 14, 2023)
6 changes: 6 additions & 0 deletions Untitled.ipynb
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 1 addition & 3 deletions doctr/models/artefacts/face.py
@@ -28,9 +28,7 @@ def __init__(
     ) -> None:
         self.n_faces = n_faces
         # Instantiate classifier
-        self.detector = cv2.CascadeClassifier(
-            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"  # type: ignore[attr-defined]
-        )
+        self.detector = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml")

    def extra_repr(self) -> str:
        return f"n_faces={self.n_faces}"
1 change: 1 addition & 0 deletions doctr/models/classification/__init__.py
@@ -4,3 +4,4 @@
 from .magc_resnet import *
 from .vit import *
 from .zoo import *
+from .textnet_fast import *
6 changes: 6 additions & 0 deletions doctr/models/classification/textnet_fast/__init__.py
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *

Codacy notices (lines 4 and 6): '.tensorflow.*' and '.pytorch.*' imported but unused (F401).
327 changes: 327 additions & 0 deletions doctr/models/classification/textnet_fast/pytorch.py
@@ -0,0 +1,327 @@
# Copyright (C) 2021-2023, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import torch.nn as nn

from doctr.datasets import VOCABS
from doctr.models.modules.layers.pytorch import RepConvLayer
from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence
from doctr.models.utils.pytorch import (
    fuse_module,
    rep_model_convert,
    rep_model_convert_deploy,
    rep_model_unconvert,
    unfuse_module,
)

from ...utils import load_pretrained_params

__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"]

default_cfgs: Dict[str, Dict[str, Any]] = {
    "textnetfast_tiny": {
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": None,
    },
    "textnetfast_small": {
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": None,
    },
    "textnetfast_base": {
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": None,
    },
}


class TextNetFast(nn.Sequential):
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official Pytorch implementation <https://github.com/czczup/FAST>`_.

    Args:
        stage1 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 1
        stage2 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 2
        stage3 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 3
        stage4 (List[Dict[str, Union[int, List[int]]]]): Configuration for stage 4
        include_top (bool, optional): Whether to include the classifier head. Defaults to True.
        num_classes (int, optional): Number of output classes. Defaults to 1000.
        cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None.
    """

    def __init__(
        self,
        stage1: List[Dict[str, Union[int, List[int]]]],
        stage2: List[Dict[str, Union[int, List[int]]]],
        stage3: List[Dict[str, Union[int, List[int]]]],
        stage4: List[Dict[str, Union[int, List[int]]]],
        include_top: bool = True,
        num_classes: int = 1000,
        cfg: Optional[Dict[str, Any]] = None,
    ) -> None:
        _layers: List[Any]
        super().__init__()
        first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2)
        self.first_conv = nn.Sequential(*first_conv)
        _layers = [self.first_conv]

        for stage in [stage1, stage2, stage3, stage4]:
            self.stage_ = nn.Sequential(*[RepConvLayer(**params) for params in stage])  # type: ignore[arg-type]
            _layers.extend([self.stage_])

        if include_top:
            classif_block = [
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(1),
                nn.Linear(512, num_classes, bias=True),
            ]
            classif_block_ = nn.Sequential(*nn.ModuleList(classif_block))
            _layers.extend([classif_block_])

        super().__init__(*_layers)
        self.cfg = cfg

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def eval(self, mode=False):
        # Switch RepConv blocks to their reparameterized form and fuse conv/BN pairs,
        # so inference runs on a single-branch graph
        self = rep_model_convert(self)
        self = fuse_module(self)
        for param in self.parameters():
            param.requires_grad = mode
        self.training = mode
        return self

    def train(self, mode=True):
        # Undo the inference-time fusion to restore the multi-branch training graph
        self = unfuse_module(self)
        self = rep_model_unconvert(self)
        for param in self.parameters():
            param.requires_grad = mode
        self.training = mode
        return self

    def test(self, mode=False):
        # Deploy-time variant: the reparameterization is applied permanently
        self = rep_model_convert_deploy(self)
        self = fuse_module(self)
        for param in self.parameters():
            param.requires_grad = mode
        self.training = mode
        return self
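    # Illustrative usage of the mode switches above (a sketch, not part of this diff;
    # the helpers' behavior is inferred from their names):
    #   model = textnetfast_tiny()
    #   model = model.train()  # multi-branch RepConv graph, parameters trainable
    #   model = model.eval()   # reparameterized + fused graph for fast inference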


def _textnetfast(
    arch: str,
    pretrained: bool,
    arch_fn,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> TextNetFast:
    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])

    _cfg = deepcopy(default_cfgs[arch])
    _cfg["num_classes"] = kwargs["num_classes"]
    _cfg["classes"] = kwargs["classes"]
    kwargs.pop("classes")

    # Build the model
    model = arch_fn(**kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    model.cfg = _cfg

    return model


def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official Pytorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnetfast_tiny
    >>> model = textnetfast_tiny(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained

    Returns:
        A TextNetFast model
    """

    return _textnetfast(
        "textnetfast_tiny",
        pretrained,
        TextNetFast,
        stage1=[
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
        ],
        stage2=[
            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
        ],
        stage3=[
            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
        ],
        stage4=[
            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1},
        ],
        ignore_keys=["4.3.weight", "4.3.bias"],
        **kwargs,
    )


def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official Pytorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnetfast_small
    >>> model = textnetfast_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained

    Returns:
        A TextNetFast model
    """

    return _textnetfast(
        "textnetfast_small",
        pretrained,
        TextNetFast,
        stage1=[
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
        ],
        stage2=[
            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
        ],
        stage3=[
            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
        ],
        stage4=[
            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
        ],
        ignore_keys=["4.3.weight", "4.3.bias"],
        **kwargs,
    )


def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast:
    """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with
    Minimalist Kernel Representation" <https://arxiv.org/abs/2111.02394>`_.
    Implementation based on the `official Pytorch implementation <https://github.com/czczup/FAST>`_.

    >>> import torch
    >>> from doctr.models import textnetfast_base
    >>> model = textnetfast_base(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained

    Returns:
        A TextNetFast model
    """

    return _textnetfast(
        "textnetfast_base",
        pretrained,
        TextNetFast,
        stage1=[
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1},
        ],
        stage2=[
            {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1},
        ],
        stage3=[
            {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1},
            {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1},
        ],
        stage4=[
            {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1},
            {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1},
        ],
        ignore_keys=["4.3.weight", "4.3.bias"],
        **kwargs,
    )
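For context on the eval-time speedup: fusing a convolution with its BatchNorm is the standard folding step in RepVGG-style reparameterization. Below is a minimal, self-contained sketch of that arithmetic; the helper fuse_conv_bn is illustrative only and is not doctr's fuse_module implementation.

import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # BN(conv(x)) = gamma * (W x + b - mu) / sqrt(var + eps) + beta
    # => W' = W * gamma / sqrt(var + eps)
    #    b' = (b - mu) * gamma / sqrt(var + eps) + beta
    fused = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=True,
    )
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # Scale each output-channel filter of the conv weight
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused

After fusion the BatchNorm disappears from the graph, which is why the train() override above has to unfuse before gradient updates make sense again.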
3 changes: 3 additions & 0 deletions doctr/models/classification/zoo.py
@@ -27,6 +27,9 @@
     "vgg16_bn_r",
     "vit_s",
     "vit_b",
+    "textnetfast_tiny",
+    "textnetfast_small",
+    "textnetfast_base",
 ]
 ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
1 change: 1 addition & 0 deletions doctr/models/detection/__init__.py
@@ -1,3 +1,4 @@
 from .differentiable_binarization import *
 from .linknet import *
 from .zoo import *
+from .fast import *

Codacy notice (line 4): '.fast.*' imported but unused (F401).
6 changes: 6 additions & 0 deletions doctr/models/detection/fast/__init__.py
@@ -0,0 +1,6 @@
from doctr.file_utils import is_tf_available, is_torch_available

if is_tf_available():
    from .tensorflow import *
elif is_torch_available():
    from .pytorch import *  # type: ignore[assignment]

Codacy notices (lines 4 and 6): '.tensorflow.*' and '.pytorch.*' imported but unused (F401).