Skip to content

Commit

Permalink
[orientation] Part 1: Add page orientation predictor (mindee#1566)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Apr 26, 2024
1 parent 5568612 commit bf6c34e
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 38 deletions.
4 changes: 4 additions & 0 deletions docs/source/modules/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation

.. autofunction:: doctr.models.classification.mobilenet_v3_small_page_orientation

.. autofunction:: doctr.models.classification.magc_resnet31

.. autofunction:: doctr.models.classification.vit_s
Expand All @@ -41,6 +43,8 @@ doctr.models.classification

.. autofunction:: doctr.models.classification.crop_orientation_predictor

.. autofunction:: doctr.models.classification.page_orientation_predictor


doctr.models.detection
----------------------
Expand Down
26 changes: 26 additions & 0 deletions doctr/models/classification/mobilenet/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_crop_orientation",
"mobilenet_v3_small_page_orientation",
]

default_cfgs: Dict[str, Dict[str, Any]] = {
Expand Down Expand Up @@ -245,3 +246,28 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
ignore_keys=["classifier.3.weight", "classifier.3.bias"],
**kwargs,
)


def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """Build the MobileNetV3-Small model registered under the page orientation configuration.

    Architecture reference: `"Searching for MobileNetV3"
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_page_orientation
    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a torch.nn.Module
    """
    # Parameters of the last classifier layer, forwarded to the builder as `ignore_keys`.
    # NOTE(review): presumably these are skipped when loading a checkpoint whose head
    # differs -- confirm against `_mobilenet_v3`.
    head_param_keys = ["classifier.3.weight", "classifier.3.bias"]
    arch_name = "mobilenet_v3_small_page_orientation"
    return _mobilenet_v3(arch_name, pretrained, ignore_keys=head_param_keys, **kwargs)
21 changes: 21 additions & 0 deletions doctr/models/classification/mobilenet/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"mobilenet_v3_large",
"mobilenet_v3_large_r",
"mobilenet_v3_small_crop_orientation",
"mobilenet_v3_small_page_orientation",
]


Expand Down Expand Up @@ -414,3 +415,23 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any)
a keras.Model
"""
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)


def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
    """Build the MobileNetV3-Small model registered under the page orientation configuration.

    Architecture reference: `"Searching for MobileNetV3"
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import tensorflow as tf
    >>> from doctr.models import mobilenet_v3_small_page_orientation
    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
    >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture

    Returns:
    -------
        a keras.Model
    """
    # The classification head is always requested for this variant (include_top=True).
    arch_name = "mobilenet_v3_small_page_orientation"
    return _mobilenet_v3(arch_name, pretrained, include_top=True, **kwargs)
12 changes: 6 additions & 6 deletions doctr/models/classification/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@


class OrientationPredictor(nn.Module):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
"""Implements an object able to detect the reading direction of a text box or a page.
4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
Args:
----
Expand All @@ -37,13 +37,13 @@ def __init__(
@torch.inference_mode()
def forward(
self,
crops: List[Union[np.ndarray, torch.Tensor]],
inputs: List[Union[np.ndarray, torch.Tensor]],
) -> List[Union[List[int], List[float]]]:
# Dimension check
if any(crop.ndim != 3 for crop in crops):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
if any(input.ndim != 3 for input in inputs):
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

processed_batches = self.pre_processor(crops)
processed_batches = self.pre_processor(inputs)
_params = next(self.model.parameters())
self.model, processed_batches = set_device_and_dtype(
self.model, processed_batches, _params.device, _params.dtype
Expand Down
12 changes: 6 additions & 6 deletions doctr/models/classification/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@


class OrientationPredictor(NestedObject):
"""Implements an object able to detect the reading direction of a text box.
4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.
"""Implements an object able to detect the reading direction of a text box or a page.
4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise.
Args:
----
Expand All @@ -37,13 +37,13 @@ def __init__(

def __call__(
self,
crops: List[Union[np.ndarray, tf.Tensor]],
inputs: List[Union[np.ndarray, tf.Tensor]],
) -> List[Union[List[int], List[float]]]:
# Dimension check
if any(crop.ndim != 3 for crop in crops):
raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.")
if any(input.ndim != 3 for input in inputs):
raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.")

processed_batches = self.pre_processor(crops)
processed_batches = self.pre_processor(inputs)
predicted_batches = [self.model(batch, training=False) for batch in processed_batches]

# confidence
Expand Down
38 changes: 31 additions & 7 deletions doctr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor

__all__ = ["crop_orientation_predictor"]
__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

ARCHS: List[str] = [
"magc_resnet31",
Expand All @@ -31,7 +31,7 @@
"vit_s",
"vit_b",
]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation"]
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
Expand All @@ -42,7 +42,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
_model = classification.__dict__[arch](pretrained=pretrained)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128)
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
Expand All @@ -53,17 +53,41 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
def crop_orientation_predictor(
arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Orientation classification architecture.
"""Crop orientation classification architecture.
>>> import numpy as np
>>> from doctr.models import crop_orientation_predictor
>>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True)
>>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
>>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
>>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
>>> out = model([input_crop])
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
pretrained: If True, returns a model pre-trained on our recognition crops dataset
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, pretrained, **kwargs)


def page_orientation_predictor(
arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Page orientation classification architecture.
>>> import numpy as np
>>> from doctr.models import page_orientation_predictor
>>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
>>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
>>> out = model([input_page])
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
pretrained: If True, returns a model loaded with pre-trained page orientation weights
**kwargs: keyword arguments to be passed to the OrientationPredictor
Expand Down
46 changes: 36 additions & 10 deletions tests/pytorch/test_models_classification_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def test_classification_architectures(arch_name, input_shape, output_size):
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["mobilenet_v3_small_crop_orientation", (3, 128, 128)],
["mobilenet_v3_small_crop_orientation", (3, 256, 256)],
["mobilenet_v3_small_page_orientation", (3, 512, 512)],
],
)
def test_classification_models(arch_name, input_shape):
Expand All @@ -81,19 +82,30 @@ def test_classification_models(arch_name, input_shape):
"arch_name",
[
"mobilenet_v3_small_crop_orientation",
"mobilenet_v3_small_page_orientation",
],
)
def test_classification_zoo(arch_name):
batch_size = 16
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)
predictor.model.eval()

with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
if "crop" in arch_name:
batch_size = 16
input_tensor = torch.rand((batch_size, 3, 256, 256))
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)
predictor.model.eval()

with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
else:
batch_size = 2
input_tensor = torch.rand((batch_size, 3, 512, 512))
# Model
predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False)
predictor.model.eval()

with pytest.raises(ValueError):
predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", pretrained=False)
# object check
assert isinstance(predictor, OrientationPredictor)
input_tensor = torch.rand((batch_size, 3, 128, 128))
if torch.cuda.is_available():
predictor.model.cuda()
input_tensor = input_tensor.cuda()
Expand Down Expand Up @@ -123,6 +135,19 @@ def test_crop_orientation_model(mock_text_box):
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])


def test_page_orientation_model(mock_payslip):
    """Check the pretrained page orientation predictor on one page in four rotations.

    Args:
    ----
        mock_payslip: fixture giving the path of a sample page image readable by ``cv2.imread``
    """
    page_0 = cv2.imread(mock_payslip)
    # np.rot90 rotates counter-clockwise in steps of 90 degrees
    page_270 = np.rot90(page_0, 1)
    page_180 = np.rot90(page_0, 2)
    page_90 = np.rot90(page_0, 3)
    # Use the page-level factory so the test exercises the public entry point for this model
    classifier = classification.page_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True)
    # Run inference once and reuse the result for all three checks
    out = classifier([page_0, page_270, page_180, page_90])
    assert out[0] == [0, 1, 2, 3]
    # 270 degrees is equivalent to -90 degrees
    assert out[1] == [0, -90, 180, 90]
    assert all(isinstance(pred, float) for pred in out[2])


@pytest.mark.parametrize(
"arch_name, input_shape, output_size",
[
Expand All @@ -135,7 +160,8 @@ def test_crop_orientation_model(mock_text_box):
["magc_resnet31", (3, 32, 32), (126,)],
["mobilenet_v3_small", (3, 32, 32), (126,)],
["mobilenet_v3_large", (3, 32, 32), (126,)],
["mobilenet_v3_small_crop_orientation", (3, 128, 128), (4,)],
["mobilenet_v3_small_crop_orientation", (3, 256, 256), (4,)],
["mobilenet_v3_small_page_orientation", (3, 512, 512), (4,)],
["vit_s", (3, 32, 32), (126,)],
["vit_b", (3, 32, 32), (126,)],
["textnet_tiny", (3, 32, 32), (126,)],
Expand Down
47 changes: 38 additions & 9 deletions tests/tensorflow/test_models_classification_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def test_classification_architectures(arch_name, input_shape, output_size):
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["mobilenet_v3_small_crop_orientation", (128, 128, 3)],
["mobilenet_v3_small_crop_orientation", (256, 256, 3)],
["mobilenet_v3_small_page_orientation", (512, 512, 3)],
],
)
def test_classification_models(arch_name, input_shape):
Expand All @@ -68,17 +69,28 @@ def test_classification_models(arch_name, input_shape):
"arch_name",
[
"mobilenet_v3_small_crop_orientation",
"mobilenet_v3_small_page_orientation",
],
)
def test_classification_zoo(arch_name):
batch_size = 16
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)
with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
if "crop" in arch_name:
batch_size = 16
input_tensor = tf.random.uniform(shape=[batch_size, 256, 256, 3], minval=0, maxval=1)
# Model
predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False)

with pytest.raises(ValueError):
predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
else:
batch_size = 2
input_tensor = tf.random.uniform(shape=[batch_size, 512, 512, 3], minval=0, maxval=1)
# Model
predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False)

with pytest.raises(ValueError):
predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", pretrained=False)
# object check
assert isinstance(predictor, OrientationPredictor)
input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1)
out = predictor(input_tensor)
class_idxs, classes, confs = out[0], out[1], out[2]
assert isinstance(class_idxs, list) and len(class_idxs) == batch_size
Expand All @@ -102,6 +114,22 @@ def test_crop_orientation_model(mock_text_box):
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])


# TODO: uncomment when model is available
"""
def test_page_orientation_model(mock_payslip):
text_box_0 = cv2.imread(mock_payslip)
# rotates counter-clockwise
text_box_270 = np.rot90(text_box_0, 1)
text_box_180 = np.rot90(text_box_0, 2)
text_box_90 = np.rot90(text_box_0, 3)
classifier = classification.page_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True)
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
# 270 degrees is equivalent to -90 degrees
assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])
"""


# temporary fix to avoid killing the CI (tf2onnx v1.14 memory leak issue)
# ref.: https://github.com/mindee/doctr/pull/1201
@pytest.mark.parametrize(
Expand All @@ -110,7 +138,8 @@ def test_crop_orientation_model(mock_text_box):
["vgg16_bn_r", (32, 32, 3), (126,)],
["mobilenet_v3_small", (512, 512, 3), (126,)],
["mobilenet_v3_large", (512, 512, 3), (126,)],
["mobilenet_v3_small_crop_orientation", (128, 128, 3), (4,)],
["mobilenet_v3_small_crop_orientation", (256, 256, 3), (4,)],
["mobilenet_v3_small_page_orientation", (512, 512, 3), (4,)],
["resnet18", (32, 32, 3), (126,)],
["vit_s", (32, 32, 3), (126,)],
["textnet_tiny", (32, 32, 3), (126,)],
Expand Down Expand Up @@ -163,7 +192,7 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
# Model
batch_size = 2
tf.keras.backend.clear_session()
if arch_name == "mobilenet_v3_small_crop_orientation":
if "orientation" in arch_name:
model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape)
else:
model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape)
Expand Down

0 comments on commit bf6c34e

Please sign in to comment.