diff --git a/Untitled.ipynb b/Untitled.ipynb new file mode 100644 index 000000000..363fcab7e --- /dev/null +++ b/Untitled.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doctr/models/artefacts/face.py b/doctr/models/artefacts/face.py index f79200a07..7f8d2d4d7 100644 --- a/doctr/models/artefacts/face.py +++ b/doctr/models/artefacts/face.py @@ -28,9 +28,7 @@ def __init__( ) -> None: self.n_faces = n_faces # Instantiate classifier - self.detector = cv2.CascadeClassifier( - cv2.data.haarcascades + "haarcascade_frontalface_default.xml" # type: ignore[attr-defined] - ) + self.detector = cv2.CascadeClassifier(cv2.data.haarcascades + "haarcascade_frontalface_default.xml") def extra_repr(self) -> str: return f"n_faces={self.n_faces}" diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py index 72e68b78d..e1b303ef2 100644 --- a/doctr/models/classification/__init__.py +++ b/doctr/models/classification/__init__.py @@ -4,3 +4,4 @@ from .magc_resnet import * from .vit import * from .zoo import * +from .textnet_fast import * diff --git a/doctr/models/classification/textnet_fast/__init__.py b/doctr/models/classification/textnet_fast/__init__.py new file mode 100644 index 000000000..64556e403 --- /dev/null +++ b/doctr/models/classification/textnet_fast/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/models/classification/textnet_fast/pytorch.py b/doctr/models/classification/textnet_fast/pytorch.py new file mode 100644 index 000000000..bce89525d --- /dev/null +++ b/doctr/models/classification/textnet_fast/pytorch.py @@ -0,0 +1,327 @@ +# Copyright (C) 2021-2023, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Union + +import torch.nn as nn + +from doctr.datasets import VOCABS +from doctr.models.modules.layers.pytorch import RepConvLayer +from doctr.models.utils.pytorch import conv_sequence_pt as conv_sequence +from doctr.models.utils.pytorch import ( + fuse_module, + rep_model_convert, + rep_model_convert_deploy, + rep_model_unconvert, + unfuse_module, +) + +from ...utils import load_pretrained_params + +__all__ = ["textnetfast_tiny", "textnetfast_small", "textnetfast_base"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + "textnetfast_tiny": { + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": None, + }, + "textnetfast_small": { + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": None, + }, + "textnetfast_base": { + "input_shape": (3, 32, 32), + "classes": list(VOCABS["french"]), + "url": None, + }, +} + + +class TextNetFast(nn.Sequential): + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + Args: + stage1 (Dict[str, Union[int, List[int]]]): Configuration for stage 1 + stage2 (Dict[str, Union[int, List[int]]]): Configuration for stage 2 + stage3 (Dict[str, Union[int, List[int]]]): Configuration for stage 3 + stage4 (Dict[str, Union[int, List[int]]]): Configuration for stage 4 + include_top (bool, optional): Whether to include the classifier head. Defaults to True. 
+ num_classes (int, optional): Number of output classes. Defaults to 1000. + cfg (Optional[Dict[str, Any]], optional): Additional configuration. Defaults to None. + """ + + def __init__( + self, + stage1: List[Dict[str, Union[int, List[int]]]], + stage2: List[Dict[str, Union[int, List[int]]]], + stage3: List[Dict[str, Union[int, List[int]]]], + stage4: List[Dict[str, Union[int, List[int]]]], + include_top: bool = True, + num_classes: int = 1000, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + _layers: List[Any] + super().__init__() + first_conv = conv_sequence(in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2) + self.first_conv = nn.Sequential(*first_conv) + _layers = [self.first_conv] + + for stage in [stage1, stage2, stage3, stage4]: + self.stage_ = nn.Sequential(*[RepConvLayer(**params) for params in stage]) # type: ignore[arg-type] + _layers.extend([self.stage_]) + + if include_top: + classif_block = [ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(1), + nn.Linear(512, num_classes, bias=True), + ] + classif_block_ = nn.Sequential(*nn.ModuleList(classif_block)) + _layers.extend([classif_block_]) + + super().__init__(*_layers) + self.cfg = cfg + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def eval(self, mode=False): + self = rep_model_convert(self) + self = fuse_module(self) + for param in self.parameters(): + param.requires_grad = mode + self.training = mode + return self + + def train(self, mode=True): + self = unfuse_module(self) + self = rep_model_unconvert(self) + for param in self.parameters(): + param.requires_grad = mode + self.training = mode + return self + + def test(self, mode=False): + self = rep_model_convert_deploy(self) + self = fuse_module(self) + for param in self.parameters(): + param.requires_grad = mode + self.training = mode + return self + + +def _textnetfast( + arch: str, + pretrained: bool, + arch_fn, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> TextNetFast: + kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg["num_classes"] = kwargs["num_classes"] + _cfg["classes"] = kwargs["classes"] + kwargs.pop("classes") + + # Build the model + model = arch_fn(**kwargs) + # Load pretrained parameters + if pretrained: + # The number of classes is not the same as the number of classes in the pretrained model => + # remove the last layer weights + _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + model.cfg = _cfg + + return model + + +def textnetfast_tiny(pretrained: bool = False, **kwargs: Any) -> TextNetFast: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import torch + >>> from doctr.models import textnetfast_tiny + >>> model = textnetfast_tiny(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A TextNet model + """ + + return _textnetfast( + "textnetfast_tiny", + pretrained, + TextNetFast, + stage1=[ + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + ], + stage2=[ + {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + ], + stage3=[ + {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1}, + ], + stage4=[ + {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 3], "stride": 1}, + ], + ignore_keys=["4.3.weight", "4.3.bias"], + **kwargs, + ) + + +def textnetfast_small(pretrained: bool = False, **kwargs: Any) -> TextNetFast: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import torch + >>> from doctr.models import textnetfast_small + >>> model = textnetfast_small(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A TextNetFast model + """ + + return _textnetfast( + "textnetfast_small", + pretrained, + TextNetFast, + stage1=[ + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}, + ], + stage2=[ + {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + ], + stage3=[ + {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + ], + stage4=[ + {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1}, + ], + ignore_keys=["4.3.weight", "4.3.bias"], + **kwargs, + ) + + +def textnetfast_base(pretrained: bool = False, **kwargs: Any) -> TextNetFast: + """Implements TextNet architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import torch + >>> from doctr.models import textnetfast_base + >>> model = textnetfast_base(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A TextNetFast model + """ + + return _textnetfast( + "textnetfast_base", + pretrained, + TextNetFast, + stage1=[ + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 64, "out_channels": 64, "kernel_size": [3, 3], "stride": 1}, + ], + stage2=[ + {"in_channels": 64, "out_channels": 128, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 128, "out_channels": 128, "kernel_size": [3, 3], "stride": 1}, + ], + stage3=[ + {"in_channels": 128, "out_channels": 256, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 3], "stride": 1}, + {"in_channels": 256, "out_channels": 256, "kernel_size": [3, 1], "stride": 1}, + ], + stage4=[ + {"in_channels": 256, "out_channels": 512, "kernel_size": [3, 3], "stride": 2}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [3, 1], "stride": 1}, + {"in_channels": 512, "out_channels": 512, "kernel_size": [1, 3], "stride": 1}, + ], + ignore_keys=["4.3.weight", "4.3.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 9ec80a261..c573373d7 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -27,6 +27,9 @@ "vgg16_bn_r", "vit_s", "vit_b", + "textnetfast_tiny", + "textnetfast_small", + "textnetfast_base", ] ORIENTATION_ARCHS: List[str] = 
["mobilenet_v3_small_orientation"] diff --git a/doctr/models/detection/__init__.py b/doctr/models/detection/__init__.py index e2fafbadb..702ee501d 100644 --- a/doctr/models/detection/__init__.py +++ b/doctr/models/detection/__init__.py @@ -1,3 +1,4 @@ from .differentiable_binarization import * from .linknet import * from .zoo import * +from .fast import * diff --git a/doctr/models/detection/fast/__init__.py b/doctr/models/detection/fast/__init__.py new file mode 100644 index 000000000..c7110f566 --- /dev/null +++ b/doctr/models/detection/fast/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py new file mode 100644 index 000000000..727e10e00 --- /dev/null +++ b/doctr/models/detection/fast/base.py @@ -0,0 +1,139 @@ +# Copyright (C) 2021-2023, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import List, Union + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from ..core import DetectionPostProcessor + +__all__ = ["FastPostProcessor"] + + +class FastPostProcessor(DetectionPostProcessor): + """Implements a post processor for LinkNet model. + + Args: + bin_thresh: threshold used to binzarized p_map at inference time + box_thresh: minimal objectness score to consider a box + assume_straight_pages: whether the inputs were expected to have horizontal text elements + """ + + def __init__( + self, + bin_thresh: float = 0.1, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + ) -> None: + super().__init__(box_thresh, bin_thresh, assume_straight_pages) + self.unclip_ratio = 1.2 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + points: The first parameter. 
+ + Returns: + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points: np.ndarray = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None # type: ignore[return-value] + return ( + cv2.boundingRect(expanded_points) + if self.assume_straight_pages + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + pred: Pred map from differentiable linknet output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, score for the box + """ + height, width = bitmap.shape[:2] + boxes: List[Union[np.ndarray, List[float]]] = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + if self.assume_straight_pages: + # compute relative polygon to get rid of img shape + x, y, w, h = _box + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py new file mode 
100644 index 000000000..4dba5c278 --- /dev/null +++ b/doctr/models/detection/fast/pytorch.py @@ -0,0 +1,533 @@ +# Copyright (C) 2021-2023, Mindee. + +# This program is licensed under the Apache License 2.0. +# See LICENSE or go to for full license details. + +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from PIL import Image + +from doctr.file_utils import CLASS_NAME +from doctr.models.classification.textnet_fast.pytorch import textnetfast_tiny +from doctr.models.modules.layers.pytorch import ConvLayer, RepConvLayer +from doctr.utils.metrics import box_iou +import cv2 +from ...utils import load_pretrained_params +from .base import FastPostProcessor + +__all__ = ["fast_tiny", "fast_small", "fast_base"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + "fast_tiny": { + "input_shape": (3, 1024, 1024), + "url": None, + }, + "fast_small": { + "input_shape": (3, 1024, 1024), + "url": None, + }, + "fast_base": { + "input_shape": (3, 1024, 1024), + "url": None, + }, +} + +# implement FastPostProcessing class with get_results head class + + +class FAST(nn.Module): + def __init__( + self, + feat_extractor, + bin_thresh: float = 0.1, + head_chans: int = 32, + assume_straight_pages: bool = True, + exportable: bool = False, + cfg: Optional[Dict[str, Any]] = None, + class_names: List[str] = [CLASS_NAME], + ) -> None: + super().__init__() + self.class_names = class_names + self.num_classes = len(self.class_names) + self.cfg = cfg + self.exportable = exportable + self.assume_straight_pages = assume_straight_pages + self.feat_extractor = feat_extractor + self.feat_extractor.train() + self.fpn = FASTNeck() + self.classifier = FASTHead() + self.postprocessor = FastPostProcessor(assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh) + self.overlap_pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + self.pooling = nn.MaxPool2d(kernel_size=9, stride=1) + self.pad = nn.ZeroPad2d(padding=(9 - 1) // 2) + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith("feat_extractor."): + continue + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, torch.Tensor]: + + x, gt_texts, gt_kernels, training_masks, gt_instances, img_metas = self.prepare_data(x, target) + + feats = self.backbone(x) + logits = self.fpn(feats) + logits = self.classifier(logits) + logits = self._upsample(logits, x.size(), scale=1) + + out: Dict[str, Any] = {} + if self.exportable: + out["logits"] = logits + return out + + if return_model_output or target is None or return_preds: + prob_map = torch.sigmoid(logits) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes + out["preds"] = [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy(), img_metas, cfg, scale=2) + ] + + if target is not None: + loss = self.compute_loss(logits, gt_texts, gt_kernels, training_masks, gt_instances) + out["loss"] = loss + + return out + + def compute_loss(self, out_map: 
torch.Tensor, target: List[np.ndarray]) -> torch.Tensor: + # TODO: the following parameters are still missing from this signature (gt_kernels, training_masks, gt_instances) + + # output + kernels = out_map[:, 0, :, :] # 4*640*640 + texts = self._max_pooling(kernels, scale=1) # 4*640*640 + embs = out_map[:, 1:, :, :] # 4*4*640*640 + + # text loss + loss_text = multiclass_dice_loss(texts, target, self.num_classes, loss_weight=0.25) + iou_text = box_iou((texts > 0).long(), target) + losses = dict(loss_text=loss_text, iou_text=iou_text) + + # kernel loss + loss_kernel = multiclass_dice_loss(kernels, None, self.num_classes, loss_weight=1.0) + loss_kernel = torch.mean(loss_kernel, dim=0) + iou_kernel = box_iou((kernels > 0).long(), None) + losses.update(dict(loss_kernels=loss_kernel, iou_kernel=iou_kernel)) + + # auxiliary loss + loss_emb = emb_loss_v2(embs, None, None, None) + losses.update(dict(loss_emb=loss_emb)) + + return losses + + def _max_pooling(self, x, scale=1): + if scale == 1: + x = self.pooling_1s(x) + elif scale == 2: + x = self.pooling_2s(x) + return x + + def _upsample(self, x, size, scale=1): + _, _, H, W = size + return F.interpolate(x, size=(H // scale, W // scale), mode="bilinear") + + def prepare_data(self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None): + + target = np.array([dico['words'] for dico in target[:self.num_classes]]).reshape(-1,1) + + gt_instance = np.zeros(x.shape[0:2], dtype='uint8') + training_mask = np.ones(x.shape[0:2], dtype='uint8') + + if target.shape[0] > 0: + target = np.reshape(target * ([x.shape[1], x.shape[0]] * 4), + (target.shape[0], -1, 2)).astype('int32') + for i in range(target.shape[0]): + cv2.drawContours(gt_instance, [target[i]], -1, i + 1, -1) + + gt_kernels = np.array([np.zeros(x.shape[0:2], dtype='uint8')] * len(target)) # [instance_num, h, w] + gt_kernel = self.min_pooling(gt_kernels) + + shrink_kernel_scale = 0.1 + gt_kernel_shrinked = np.zeros(x.shape[0:2], dtype='uint8') + kernel_target = shrink(target, shrink_kernel_scale) + + for i in range(target.shape[0]): + cv2.drawContours(gt_kernel_shrinked, [kernel_target[i]], -1, 1, -1) + gt_kernel = np.maximum(gt_kernel, gt_kernel_shrinked) + + gt_text = gt_instance.copy() + gt_text[gt_text > 0] = 1 + + x = Image.fromarray(x) + + img_meta = dict( + org_img_size=np.array(img.shape[:2]), + img_size=np.array(img.shape[:2])) + + img = scale_aligned_short(img, self.short_size) + x = x.convert('RGB') + + x = transforms.ToTensor()(x) + x = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(x) + + return x, torch.from_numpy(gt_text).long(), \ + torch.from_numpy(gt_kernel).long(), \ + torch.from_numpy(training_mask).long(), \ + torch.from_numpy(gt_instance).long(), \ + img_meta + + # simplify this method + def min_pooling(self, input): + input = torch.tensor(input, dtype=torch.float) + temp = input.sum(dim=0).to(torch.uint8) + overlap = (temp > 1).to(torch.float32).unsqueeze(0).unsqueeze(0) + overlap = self.overlap_pool(overlap).squeeze(0).squeeze(0) + + B = input.size(0) + h_sum = input.sum(dim=2) > 0 + + h_sum_ = h_sum.long() * torch.arange(h_sum.shape[1], 0, -1) + h_min = torch.argmax(h_sum_, 1, keepdim=True) + h_sum_ = h_sum.long() * torch.arange(1, h_sum.shape[1] + 1) + h_max = torch.argmax(h_sum_, 1, keepdim=True) + + w_sum = input.sum(dim=1) > 0 + w_sum_ = w_sum.long() * torch.arange(w_sum.shape[1], 0, -1) + w_min = torch.argmax(w_sum_, 1, keepdim=True) + w_sum_ = w_sum.long() * torch.arange(1, w_sum.shape[1] + 1) + w_max = torch.argmax(w_sum_, 1, keepdim=True) + + for i in range(B): + region = 
input[i:i + 1, h_min[i]:h_max[i] + 1, w_min[i]:w_max[i] + 1] + region = self.pad(region) + region = -self.pooling(-region) + input[i:i + 1, h_min[i]:h_max[i] + 1, w_min[i]:w_max[i] + 1] = region + + x = input.sum(dim=0).to(torch.uint8) + x[overlap > 0] = 0 # overlapping regions + return x.numpy() + +class FASTHead(nn.Module): + def __init__(self): + super(FASTHead, self).__init__() + self.conv = RepConvLayer(in_channels=512, out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1) + + self.final = ConvLayer( + kernel_size=1, + stride=1, + dilation=1, + groups=1, + bias=False, + has_shuffle=False, + in_channels=128, + out_channels=5, + use_bn=False, + act_func=None, + dropout_rate=0, + ops_order="weight", + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def forward(self, x): + x = self.conv(x) + x = self.final(x) + return x + + +class FASTNeck(nn.Module): + def __init__(self, reduce_layers=[64, 128, 256, 512]): + super(FASTNeck, self).__init__() + + self.reduce_layer1 = RepConvLayer( + in_channels=reduce_layers[0], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1 + ) + self.reduce_layer2 = RepConvLayer( + in_channels=reduce_layers[1], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1 + ) + self.reduce_layer3 = RepConvLayer( + in_channels=reduce_layers[2], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1 + ) + self.reduce_layer4 = RepConvLayer( + in_channels=reduce_layers[3], out_channels=128, kernel_size=[3, 3], stride=1, dilation=1, groups=1 + ) + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _upsample(self, x, y): + _, _, H, W = y.size() + return F.interpolate(x, size=(H, W), mode="bilinear") + + def forward(self, x): + f1, f2, f3, f4 = x + f1 = self.reduce_layer1(f1) + f2 = self.reduce_layer2(f2) + f3 = self.reduce_layer3(f3) + f4 = self.reduce_layer4(f4) + + f2 = self._upsample(f2, f1) + f3 = self._upsample(f3, f1) + f4 = self._upsample(f4, f1) + f = torch.cat((f1, f2, f3, f4), 1) + return f + + +def _fast( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + pretrained_backbone: bool = True, + ignore_keys: Optional[List[str]] = None, + **kwargs: Any, +) -> FAST: + pretrained_backbone = pretrained_backbone and not pretrained + + # TODO: fix the encapsulation of the backbone, neck and head + backbone = backbone_fn(pretrained_backbone) + FASTNeck() + FASTHead() + + feat_extractor = backbone + + if not kwargs.get("class_names", None): + kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME]) + else: + kwargs["class_names"] = sorted(kwargs["class_names"]) + + # Build the model + model = FAST(feat_extractor=feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + # The number of class_names is not the same as the number of classes in the pretrained model => + # remove the layer weights + _ignore_keys = ( + ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None + ) + load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys) + + return model + + +def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST: + """Fast architecture from 
`"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import torch + >>> from doctr.models import fast_tiny + >>> model = fast_tiny(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + return _fast( + "fast_tiny", + pretrained, + textnetfast_tiny, + # change ignore keys + ignore_keys=[ + "classifier.final.conv.weight", + "classifier.final.conv.bias", + ], + **kwargs, + ) + + +def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST: + """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. + + >>> import torch + >>> from doctr.models import fast_small + >>> model = fast_small(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + return _fast( + "fast_small", + pretrained, + textnetfast_tiny, + # change ignore keys + ignore_keys=[ + "classifier.final.conv.weight", + "classifier.final.conv.bias", + ], + **kwargs, + ) + + +def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST: + """Fast architecture from `"FAST: Faster Arbitrarily-Shaped Text Detector with + Minimalist Kernel Representation" `_. + Implementation based on the official Pytorch implementation: `_. 
+ + >>> import torch + >>> from doctr.models import fast_base + >>> model = fast_base(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + return _fast( + "fast_base", + pretrained, + textnetfast_tiny, + # change ignore keys + ignore_keys=[ + "classifier.final.conv.weight", + "classifier.final.conv.bias", + ], + **kwargs, + ) + + +# TODO: check that this code works; it is adapted from https://github.com/czczup/FAST/blob/main/models/loss/dice_loss.py +# TODO: make it possible to pass the selected_masks into this code +def multiclass_dice_loss(inputs, targets, num_classes, loss_weight=1.0): + # Convert targets to one-hot encoding + targets = F.one_hot(targets, num_classes=num_classes).permute(0, 3, 1, 2).float() + + # Calculate intersection and union + intersection = torch.sum(inputs * targets, dim=(2, 3)) + union = torch.sum(inputs, dim=(2, 3)) + torch.sum(targets, dim=(2, 3)) + + # Calculate Dice coefficients for each class + dice_coeffs = (2.0 * intersection + 1e-5) / (union + 1e-5) + + # Calculate the average Dice loss across all classes + dice_loss = 1.0 - torch.mean(dice_coeffs) + + return loss_weight * dice_loss + + +# simplify emb_loss_v2 +def emb_loss_v2(emb, instance, kernel, training_mask): + training_mask = (training_mask > 0.5).long() + kernel = (kernel > 0.5).long() + instance = instance * training_mask + instance_kernel = (instance * kernel).view(-1) + instance = instance.view(-1) + emb = emb.view(4, -1) + + unique_labels, unique_ids = torch.unique(instance_kernel, sorted=True, return_inverse=True) + num_instance = unique_labels.size(0) + if num_instance <= 1: + return 0 + + emb_mean = emb.new_zeros((4, num_instance), dtype=torch.float32) + for i, lb in enumerate(unique_labels): + if lb == 0: + continue + ind_k = instance_kernel == lb + emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1) + + l_agg = emb.new_zeros(num_instance, dtype=torch.float32) # bug + for i, lb in enumerate(unique_labels): + if lb == 0: + continue + ind = instance == lb + emb_ = emb[:, ind] + dist = (emb_ - emb_mean[:, i : i + 1]).norm(p=2, dim=0) + dist = F.relu(dist - 0.5) ** 2 + l_agg[i] = torch.mean(torch.log(dist + 1.0)) + l_agg = torch.mean(l_agg[1:]) + + if num_instance > 2: + emb_interleave = emb_mean.permute(1, 0).repeat(num_instance, 1) + emb_band = emb_mean.permute(1, 0).repeat(1, num_instance).view(-1, 4) + + mask = (1 - torch.eye(num_instance, dtype=torch.int8)).view(-1, 1).repeat(1, 4) + mask = mask.view(num_instance, num_instance, -1) + mask[0, :, :] = 0 + mask[:, 0, :] = 0 + mask = mask.view(num_instance * num_instance, -1) + + dist = emb_interleave - emb_band + dist = dist[mask > 0].view(-1, 4).norm(p=2, dim=1) + dist = F.relu(2 * 1.5 - dist) ** 2 + + l_dis = [torch.log(dist + 1.0)] + emb_bg = emb[:, instance == 0].view(4, -1) + if emb_bg.size(1) > 100: + rand_ind = np.random.permutation(emb_bg.size(1))[:100] + emb_bg = emb_bg[:, rand_ind] + if emb_bg.size(1) > 0: + for i, lb in enumerate(unique_labels): + if lb == 0: + continue + dist = (emb_bg - emb_mean[:, i : i + 1]).norm(p=2, dim=0) + dist = F.relu(2 * 1.5 - dist) ** 2 + l_dis_bg = torch.mean(torch.log(dist + 1.0), 0, keepdim=True) + l_dis.append(l_dis_bg) + l_dis = torch.mean(torch.cat(l_dis)) + else: + l_dis = 0 + l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001 + loss = l_agg + l_dis + l_reg + return loss 
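+
+# A minimal batched wrapper sketch (not part of the upstream FAST code): emb_loss_v2 above scores a
+# single sample, so a training step still has to reduce it over the batch. Assuming emb has shape
+# (N, 4, H, W) and instance/kernel/training_mask have shape (N, H, W), this could look like:
+#
+#     def emb_loss_batch(emb, instance, kernel, training_mask, reduce=True):
+#         loss_batch = emb.new_zeros((emb.size(0),), dtype=torch.float32)
+#         for i in range(loss_batch.size(0)):
+#             loss_batch[i] = emb_loss_v2(emb[i], instance[i], kernel[i], training_mask[i])
+#         loss_batch = 0.25 * loss_batch
+#         return torch.mean(loss_batch) if reduce else loss_batch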
+ + def forward(self, emb, instance, kernel, training_mask, reduce=True): + loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32) + + for i in range(loss_batch.size(0)): + loss_batch[i] = self.forward_single(emb[i], instance[i], kernel[i], training_mask[i]) + + loss_batch = 0.25 * loss_batch + + if reduce: + loss_batch = torch.mean(loss_batch) + + return loss_batch + diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index a07febdf2..c712df2ae 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -28,6 +28,9 @@ "linknet_resnet18", "linknet_resnet34", "linknet_resnet50", + "fast_tiny", + "fast_small", + "fast_base", ] ROT_ARCHS = ["db_resnet50_rotation"] diff --git a/doctr/models/modules/layers/__init__.py b/doctr/models/modules/layers/__init__.py new file mode 100644 index 000000000..c7110f566 --- /dev/null +++ b/doctr/models/modules/layers/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[assignment] diff --git a/doctr/models/modules/layers/pytorch.py b/doctr/models/modules/layers/pytorch.py new file mode 100644 index 000000000..ae1f4ebed --- /dev/null +++ b/doctr/models/modules/layers/pytorch.py @@ -0,0 +1,364 @@ +from collections import OrderedDict +from typing import Any, Union + +import numpy as np +import torch +import torch.nn as nn + +__all__ = ["RepConvLayer"] + + +def get_same_padding(kernel_size): + if isinstance(kernel_size, tuple): + assert len(kernel_size) == 2, "invalid kernel size: %s" % kernel_size + p1 = get_same_padding(kernel_size[0]) + p2 = get_same_padding(kernel_size[1]) + return p1, p2 + assert isinstance(kernel_size, int), "kernel size should be either `int` or `tuple`" + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def build_activation(act_func, inplace=True): + if act_func == "relu": + return nn.ReLU(inplace=inplace) + elif act_func == "relu6": + return nn.ReLU6(inplace=inplace) + elif act_func == "tanh": + return nn.Tanh() + elif act_func == "sigmoid": + return nn.Sigmoid() + elif act_func is None: + return None + else: + raise ValueError("do not support: %s" % act_func) + + +class RepConvLayer(nn.Module): + """Reparameterized Convolutional Layer""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[Any], + groups: int = 1, + deploy: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + dilation = kwargs.get("dilation", 1) + stride = kwargs.get("stride", 1) + kwargs.pop("padding", None) + kwargs.pop("bias", None) + + self.hor_conv, self.hor_bn = None, None + self.ver_conv, self.ver_bn = None, None + + padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2)) + + self.activation = nn.ReLU(inplace=True) + self.main_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, + **kwargs, + ) + + self.main_bn = nn.BatchNorm2d(out_channels) + + if kernel_size[1] != 1: + self.ver_conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(kernel_size[0], 1), + padding=(int(((kernel_size[0] - 1) * dilation) / 2), 0), + bias=False, + **kwargs, + ) + self.ver_bn = nn.BatchNorm2d(out_channels) + + if kernel_size[0] != 1: + self.hor_conv = nn.Conv2d( + in_channels, + 
out_channels, + kernel_size=(1, kernel_size[1]), + padding=(0, int(((kernel_size[1] - 1) * dilation) / 2)), + bias=False, + **kwargs, + ) + self.hor_bn = nn.BatchNorm2d(out_channels) + + self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + self.deploy = deploy + + def forward(self, x: torch.Tensor) -> torch.Tensor: + main_outputs = self.main_bn(self.main_conv(x)) + + if self.ver_conv is not None and self.ver_bn is not None: + vertical_outputs = self.ver_bn(self.ver_conv(x)) + else: + vertical_outputs = 0 + + if self.hor_bn is not None and self.hor_conv is not None: + horizontal_outputs = self.hor_bn(self.hor_conv(x)) + else: + horizontal_outputs = 0 + + if self.rbr_identity is not None and self.ver_bn is not None: + id_out = self.rbr_identity(x) + else: + id_out = 0 + + return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out) + + def _identity_to_conv(self, identity): + if identity is None: + return 0, 0 + assert isinstance(identity, nn.BatchNorm2d) + if not hasattr(self, "id_tensor"): + input_dim = self.in_channels // self.groups + kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, 0, 0] = 1 + id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device) + self.id_tensor = self._pad_to_mxn_tensor(id_tensor) + kernel = self.id_tensor + running_mean = identity.running_mean + running_var = identity.running_var + gamma = identity.weight + beta = identity.bias + eps = identity.eps + std = (running_var + eps).sqrt() # type: ignore + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _fuse_bn_tensor(self, conv, bn): + kernel = conv.weight + kernel = self._pad_to_mxn_tensor(kernel) + running_mean = bn.running_mean + running_var = bn.running_var + gamma = bn.weight + beta = bn.bias + eps = bn.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def get_equivalent_kernel_bias(self): + kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.main_conv, self.main_bn) + if self.ver_conv is not None: + kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn) + else: + kernel_mx1, bias_mx1 = 0, 0 + if self.hor_conv is not None: + kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn) + else: + kernel_1xn, bias_1xn = 0, 0 + kernel_id, bias_id = self._identity_to_conv(self.rbr_identity) + kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id + bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id + return kernel_mxn, bias_mxn + + def _pad_to_mxn_tensor(self, kernel): + kernel_height, kernel_width = self.kernel_size + height, width = kernel.shape[2:] + pad_left_right = (kernel_width - width) // 2 + pad_top_down = (kernel_height - height) // 2 + return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down]) + + def switch_to_deploy(self): + if hasattr(self, "fused_conv"): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.fused_conv = nn.Conv2d( + in_channels=self.main_conv.in_channels, + out_channels=self.main_conv.out_channels, + kernel_size=self.main_conv.kernel_size, + stride=self.main_conv.stride, + 
padding=self.main_conv.padding, + dilation=self.main_conv.dilation, + groups=self.main_conv.groups, + bias=True, + ) + self.fused_conv.weight.data = kernel + self.fused_conv.bias.data = bias + self.deploy = True + for para in self.parameters(): + para.detach_() + for attr in ["main_conv", "main_bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]: + if hasattr(self, attr): + self.__delattr__(attr) + + if hasattr(self, "rbr_identity"): + self.__delattr__("rbr_identity") + + def switch_to_test(self): + kernel, bias = self.get_equivalent_kernel_bias() + self.fused_conv = nn.Conv2d( + out_channels=self.main_conv.out_channels, + kernel_size=self.main_conv.kernel_size, + stride=self.main_conv.stride, + padding=self.main_conv.padding, + dilation=self.main_conv.dilation, + groups=self.main_conv.groups, + bias=True, + ) + self.fused_conv.weight.data = kernel + self.fused_conv.bias.data = bias + for para in self.fused_conv.parameters(): + para.detach_() + self.deploy = True + + def switch_to_train(self): + if hasattr(self, "fused_conv"): + self.__delattr__("fused_conv") + self.deploy = False + + @property + def module_str(self): + return "Rep_%dx%d" % (self.kernel_size[0], self.kernel_size[1]) + + @property + def config(self): + return { + "name": RepConvLayer.__name__, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "kernel_size": self.kernel_size, + "stride": self.stride, + "dilation": self.dilation, + "groups": self.groups, + } + + @staticmethod + def build_from_config(config): + return RepConvLayer(**config) + + +class My2DLayer(nn.Module): + def __init__( + self, in_channels, out_channels, use_bn=True, act_func="relu", dropout_rate=0, ops_order="weight_bn_act" + ): + super(My2DLayer, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.use_bn = use_bn + self.act_func = act_func + self.dropout_rate = dropout_rate + self.ops_order = ops_order + + """ modules """ + modules = {} + # batch norm + if self.use_bn: + if self.bn_before_weight: + modules["bn"] = nn.BatchNorm2d(in_channels) + else: + modules["bn"] = nn.BatchNorm2d(out_channels) + else: + modules["bn"] = None + # activation + modules["act"] = build_activation(self.act_func, self.ops_list[0] != "act") + # dropout + if self.dropout_rate > 0: + modules["dropout"] = nn.Dropout2d(self.dropout_rate, inplace=True) + else: + modules["dropout"] = None + # weight + modules["weight"] = self.weight_op() + + # add modules + for op in self.ops_list: + if modules[op] is None: + continue + elif op == "weight": + if modules["dropout"] is not None: + self.add_module("dropout", modules["dropout"]) + for key in modules["weight"]: + self.add_module(key, modules["weight"][key]) + else: + self.add_module(op, modules[op]) + + @property + def ops_list(self): + return self.ops_order.split("_") + + @property + def bn_before_weight(self): + for op in self.ops_list: + if op == "bn": + return True + elif op == "weight": + return False + raise ValueError("Invalid ops_order: %s" % self.ops_order) + + def forward(self, x): + for module in self._modules.values(): + x = module(x) + return x + + @staticmethod + def is_zero_layer(): + return False + + +class ConvLayer(My2DLayer): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=False, + has_shuffle=False, + use_bn=True, + act_func="relu", + dropout_rate=0, + ops_order="weight_bn_act", + ): + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + self.groups = groups + 
self.bias = bias + self.has_shuffle = has_shuffle + + super(ConvLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order) + + def weight_op(self): + padding = get_same_padding(self.kernel_size) + if isinstance(padding, int): + padding *= self.dilation + else: + padding[0] *= self.dilation + padding[1] *= self.dilation + + weight_dict = OrderedDict() + weight_dict["conv"] = nn.Conv2d( + self.in_channels, + self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=padding, + dilation=self.dilation, + groups=self.groups, + bias=self.bias, + ) + + return weight_dict diff --git a/doctr/models/modules/layers/tensorflow.py b/doctr/models/modules/layers/tensorflow.py new file mode 100644 index 000000000..1ce7bdb7c --- /dev/null +++ b/doctr/models/modules/layers/tensorflow.py @@ -0,0 +1,91 @@ +from typing import Any + +import tensorflow as tf +from tensorflow.keras import layers + +__all__ = ["RepConvLayer"] + + +class RepConvLayer(layers.Layer): + def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, groups=1): + super(RepConvLayer, self).__init__() + + kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size + + padding = (int(((kernel_size[0] - 1) * dilation) / 2), int(((kernel_size[1] - 1) * dilation) / 2)) + + self.activation = layers.ReLU() + self.main_conv = tf.keras.Sequential( + [ + layers.ZeroPadding2D(padding=padding), + layers.Conv2D( + filters=out_channels, + kernel_size=kernel_size, + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=False, + input_shape=(None, None, in_channels), + ), + layers.BatchNormalization(), + ] + ) + + if kernel_size[1] != 1: + self.ver_conv = tf.keras.Sequential( + [ + layers.ZeroPadding2D(padding=(int(((kernel_size[0] - 1) * dilation) / 2), 0)), + layers.Conv2D( + filters=out_channels, + kernel_size=(kernel_size[0], 1), + strides=stride, + dilation_rate=(dilation, 1), + groups=groups, + use_bias=False, + input_shape=(None, None, in_channels), + ), + layers.BatchNormalization(), + ] + ) + + else: + self.ver_conv = None + + if kernel_size[0] != 1: + self.hor_conv = tf.keras.Sequential( + [ + layers.ZeroPadding2D(padding=(0, int(((kernel_size[1] - 1) * dilation) / 2))), + layers.Conv2D( + filters=out_channels, + kernel_size=(1, kernel_size[1]), + strides=stride, + dilation_rate=dilation, + groups=groups, + use_bias=False, + input_shape=(None, None, in_channels), + ), + layers.BatchNormalization(), + ] + ) + else: + self.hor_conv = None + + # self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None + + self.layers = [i for i in [self.main_conv, self.ver_conv, self.hor_conv, self.activation] if i is not None] + + def call( + self, + x: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + main_outputs = self.main_conv(x, **kwargs) + vertical_outputs = self.ver_conv(x, **kwargs) if self.ver_conv is not None else 0 + horizontal_outputs = self.hor_conv(x, **kwargs) if self.hor_conv is not None else 0 + # id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None else 0 + + p = main_outputs + vertical_outputs + q = horizontal_outputs # + id_out + r = p + q + + return self.activation(r) diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index b24030ca1..1e2b59ac1 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -150,3 +150,93 @@ def export_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.T ) 
logging.info(f"Model exported to {model_name}.onnx") return f"{model_name}.onnx" + + +def rep_model_convert(model: torch.nn.Module): + for module in model.modules(): + if hasattr(module, "switch_to_test"): + module.switch_to_test() # type: ignore + return model + + +def rep_model_unconvert(model: torch.nn.Module): + for module in model.modules(): + if hasattr(module, "switch_to_train"): + module.switch_to_train() # type: ignore + return model + + +def rep_model_convert_deploy(model: torch.nn.Module): + for module in model.modules(): + if hasattr(module, "switch_to_deploy"): + module.switch_to_deploy() # type: ignore + return model + + +def fuse_conv_bn(conv, bn): + """During inference, the functionary of batch norm layers is turned off but + only the mean and var alone channels are used, which exposes the chance to + fuse it with the preceding conv layers to save computations and simplify + network structures.""" + conv_w = conv.weight + conv_b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean) + + factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) + conv.old_weight = conv.weight + conv.old_biais = conv.bias + conv.weight = nn.Parameter(conv_w * factor.reshape([conv.out_channels, 1, 1, 1])) + conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) + + return conv + + +def fuse_module(m): + last_conv = None + last_conv_name = None + + for name, child in m.named_children(): + if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)): + if last_conv is None: # only fuse BN that is after Conv + continue + fused_conv = fuse_conv_bn(last_conv, child) + m._modules[last_conv_name] = fused_conv + m._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + fuse_module(child) + return m + + +def unfuse_conv_bn(conv, bn): + """During inference, the functionary of batch norm layers is turned off but + only the mean and var alone channels are used, which exposes the chance to + fuse it with the preceding conv layers to save computations and simplify + network structures.""" + conv.weight = conv.old_weight + conv.bias = conv.old_biais + + return conv + + +def unfuse_module(m): + last_conv = None + last_conv_name = None + + for name, child in m.named_children(): + if isinstance(child, (nn.Identity, nn.Identity)): + if last_conv is None: # only fuse BN that is after Conv + continue + unfused_conv = unfuse_conv_bn(last_conv, child) + m._modules[last_conv_name] = unfused_conv + # To reduce changes, set BN as Identity instead of deleting it. 
+ m._modules[name] = nn.BatchNorm2d(unfused_conv.out_channels) + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + unfuse_module(child) + return m diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index 8490c09f1..51f36a224 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -8,10 +8,12 @@ from typing import Any, Callable, List, Optional, Tuple, Union from zipfile import ZipFile +import numpy as np import tensorflow as tf import tf2onnx from tensorflow.keras import Model, layers +from doctr.models.modules.layers.tensorflow import RepConvLayer from doctr.utils.data import download_from_url logging.getLogger("tensorflow").setLevel(logging.DEBUG) @@ -70,8 +72,7 @@ def conv_sequence( ) -> List[layers.Layer]: """Builds a convolutional-based layer sequence - >>> from tensorflow.keras import Sequential - >>> from doctr.models import conv_sequence + >>> from doctr.models.utils import conv_sequence >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) Args: @@ -160,3 +161,115 @@ def export_model_to_onnx( logging.info(f"Model exported to {model_name}.zip") return f"{model_name}.onnx", output + + +def rep_model_convert(model): + for layer in model.layers: + if hasattr(layer, "switch_to_test"): + layer.switch_to_test() + return model + + +def rep_model_unconvert(model): + for layer in model.layers: + if hasattr(layer, "switch_to_train"): + layer.switch_to_train() + return model + + +def rep_model_convert_deploy(model): + for layer in model.layers: + if hasattr(layer, "switch_to_deploy"): + layer.switch_to_deploy() + return model + + +def fuse_conv_bn(conv, bn): + """During inference, the functionality of batch norm layers is turned off but + only the mean and variance along channels are used, which exposes the opportunity + to fuse it with the preceding conv layers to save computations and simplify + network structures.""" + + bn_weights, bn_biases, bn_running_mean, bn_running_var = bn.get_weights() + weights = conv.get_weights() + if len(weights) == 1: + conv_weights = weights[0] + conv_biases = np.zeros_like(bn_running_mean) + else: + conv_weights, conv_biases = conv.get_weights() + epsilon = bn.epsilon + scale_factor = bn_weights / np.sqrt(bn_running_var + epsilon) + + # Reshape the scale factor to match the convolutional weights shape + scale_factor = scale_factor.reshape((1, 1, 1, -1)) + + # Update convolutional weights and biases + fused_conv_weights = conv_weights * scale_factor + fused_conv_biases = (conv_biases - bn_running_mean) * scale_factor.flatten() + bn_biases + + conv.use_bias = True + conv.build(input_shape=conv.input_shape) + conv.set_weights([fused_conv_weights, fused_conv_biases]) + conv.old_weight, conv.old_biais = conv_weights, conv_biases + return conv + + +def fuse_module(model): + last_conv = None + for i, layer in enumerate(model.layers): + if isinstance(layer, (tf.keras.layers.BatchNormalization, tf.keras.layers.experimental.SyncBatchNormalization)): + if last_conv is None: # only fuse BN that is after Conv + continue + fuse_conv = fuse_conv_bn(last_conv, layer) + new_layer = tf.keras.layers.Lambda(lambda x: x) + model.layers[i] = new_layer + + setattr(layer, layer.name, new_layer) + print(last_conv.name) + print(fuse_conv.name) + print(layer.name) + print(new_layer.name) + print(model.layers[i].name) + print(model.layers[i]) + print() + elif isinstance(layer, tf.keras.layers.Conv2D): + last_conv = layer + 
elif isinstance(layer, (tf.keras.Sequential, RepConvLayer)): + fuse_module(layer) + return model + + +def unfuse_conv_bn(conv, bn): + """During inference, the functionary of batch norm layers is turned off but + only the mean and var alone channels are used, which exposes the chance to + fuse it with the preceding conv layers to save computations and simplify + network structures.""" + conv.set_weights([conv.old_weight, conv.old_biais]) + return conv + + +def unfuse_module(model): + last_conv = None + + for i, layer in enumerate(model.layers): + if isinstance(layer, tf.keras.layers.Layer): + pass + else: + continue + + if isinstance(layer, tf.keras.layers.Lambda): + if last_conv is None: + continue + unfused_conv, unfused_bn = unfuse_conv_bn(last_conv, layer) + + # In TensorFlow, we can't modify the model in-place like in PyTorch, + # so you would need to create a new model with the modified layers. + # Here, you'd replace the last_conv layer with unfused_conv and + # the current layer with unfused_bn. + + elif isinstance(layer, tf.keras.layers.Conv2D): + last_conv = layer + else: + # Recursive call for potentially nested layers (e.g., in case of a nested model) + unfuse_module(layer) + return layer diff --git a/inference.py b/inference.py new file mode 100644 index 000000000..e23a83648 --- /dev/null +++ b/inference.py @@ -0,0 +1,54 @@ +# git clone https://github.com/mindee/doctr.git +# pip install -e doctr/.[tf] +# conda install -y -c conda-forge weasyprint + +import json +import os + +import tensorflow as tf + +from doctr.io import DocumentFile +from doctr.models import ocr_predictor + +os.environ["USE_TF"] = "1" +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +gpu_devices = tf.config.experimental.list_physical_devices("GPU") +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + + +def main(args): + # Load docTR model + model = ocr_predictor(det_arch=args.arch_detection, reco_arch=args.arch_recognition, pretrained=True) + + # load image input file + single_img_doc = DocumentFile.from_images(args.input_file) + + # inference + output = model(single_img_doc) + + with open(args.output_file, "w") as f: + json.dump(output.export(), f) + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser( + description="docTR inference image script(TensorFlow)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("--arch_recognition", type=str, help="text-detection model") + parser.add_argument("--arch_detection", type=str, help="text-recognition model") + parser.add_argument("--input_file", type=str, help="path of image file") + parser.add_argument("--output_file", type=str, help="path of output file") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index 0ea879097..9184b9d25 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -44,6 +44,9 @@ def _test_classification(model, input_shape, output_size, batch_size=2): ["vit_b", (3, 32, 32), (126,)], # Check that the interpolation of positional embeddings for vit models works correctly ["vit_s", (3, 64, 64), (126,)], + ["textnetfast_tiny", (3, 32, 32), (126,)], + ["textnetfast_small", (3, 32, 32), (126,)], + ["textnetfast_base", (3, 32, 32), (126,)], ], ) def test_classification_architectures(arch_name, input_shape, output_size): @@ -125,6 
+128,9 @@ def test_crop_orientation_model(mock_text_box): ["mobilenet_v3_large", (3, 32, 32), (126,)], ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)], ["vit_b", (3, 32, 32), (126,)], + ["textnetfast_tiny", (3, 32, 32), (126,)], + ["textnetfast_small", (3, 32, 32), (126,)], + ["textnetfast_base", (3, 32, 32), (126,)], ], ) def test_models_onnx_export(arch_name, input_shape, output_size): diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py index 39eae6516..d2f4bb888 100644 --- a/tests/pytorch/test_models_detection_pt.py +++ b/tests/pytorch/test_models_detection_pt.py @@ -23,6 +23,10 @@ ["linknet_resnet18", (3, 512, 512), (1, 512, 512), True], ["linknet_resnet34", (3, 512, 512), (1, 512, 512), True], ["linknet_resnet50", (3, 512, 512), (1, 512, 512), True], + ["fast_tiny", (3, 512, 512), (1, 512, 512), True], + ["fast_small", (3, 512, 512), (1, 512, 512), True], + ["fast_base", (3, 512, 512), (1, 512, 512), True], + ], ) def test_detection_models(arch_name, input_shape, output_size, out_prob, train_mode): @@ -125,6 +129,9 @@ def test_dilate(): ["linknet_resnet18", (3, 512, 512), (1, 512, 512)], ["linknet_resnet34", (3, 512, 512), (1, 512, 512)], ["linknet_resnet50", (3, 512, 512), (1, 512, 512)], + ["fast_tiny", (3, 512, 512), (1, 512, 512), True], + ["fast_small", (3, 512, 512), (1, 512, 512), True], + ["fast_base", (3, 512, 512), (1, 512, 512), True], ], ) def test_models_onnx_export(arch_name, input_shape, output_size): diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index 25d3ca5ad..e38386ba4 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -29,6 +29,9 @@ ["mobilenet_v3_large", (32, 32, 3), (126,)], ["vit_s", (32, 32, 3), (126,)], ["vit_b", (32, 32, 3), (126,)], + ["textnetfast_tiny", (32, 32, 3), (126,)], + ["textnetfast_small", (32, 32, 3), (126,)], + ["textnetfast_base", (32, 32, 3), (126,)], ], ) def test_classification_architectures(arch_name, input_shape, output_size): @@ -136,6 +139,21 @@ def test_crop_orientation_model(mock_text_box): (126,), marks=pytest.mark.skipif(system_available_memory < 16, reason="to less memory"), ), + [ + "textnetfast_tiny", + (32, 32, 3), + (126,), + ], + [ + "textnetfast_small", + (32, 32, 3), + (126,), + ], + [ + "textnetfast_base", + (32, 32, 3), + (126,), + ], ], ) def test_models_onnx_export(arch_name, input_shape, output_size):
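
Once this branch is installed, the inference.py script added above can be smoke-tested from the command line. This is a minimal sketch: db_resnet50 and crnn_vgg16_bn are existing docTR architectures (the new fast_* detection archs could be substituted once their zoo registration above is complete), and the file paths are placeholders.

    python inference.py --arch_detection db_resnet50 --arch_recognition crnn_vgg16_bn --input_file path/to/page.jpg --output_file path/to/output.json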