diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst index 0424e9c3d..2baf095ee 100644 --- a/docs/source/modules/models.rst +++ b/docs/source/modules/models.rst @@ -27,6 +27,8 @@ doctr.models.classification .. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation +.. autofunction:: doctr.models.classification.mobilenet_v3_small_page_orientation + .. autofunction:: doctr.models.classification.magc_resnet31 .. autofunction:: doctr.models.classification.vit_s @@ -41,6 +43,8 @@ doctr.models.classification .. autofunction:: doctr.models.classification.crop_orientation_predictor +.. autofunction:: doctr.models.classification.page_orientation_predictor + doctr.models.detection ---------------------- diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py index a8b66e1a1..4c1db5b6c 100644 --- a/doctr/models/classification/mobilenet/pytorch.py +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -20,6 +20,7 @@ "mobilenet_v3_large", "mobilenet_v3_large_r", "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", ] default_cfgs: Dict[str, Dict[str, Any]] = { @@ -245,3 +246,28 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs, ) + + +def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + <https://arxiv.org/pdf/1905.02244.pdf>`_.
+ >>> import torch + >>> from doctr.models import mobilenet_v3_small_page_orientation + >>> model = mobilenet_v3_small_page_orientation(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32) + >>> out = model(input_tensor) + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + Returns: + ------- + a torch.nn.Module + """ + return _mobilenet_v3( + "mobilenet_v3_small_page_orientation", + pretrained, + ignore_keys=["classifier.3.weight", "classifier.3.bias"], + **kwargs, + ) diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py index 92adc887b..3e0b99a9e 100644 --- a/doctr/models/classification/mobilenet/tensorflow.py +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -22,6 +22,7 @@ "mobilenet_v3_large", "mobilenet_v3_large_r", "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", ] @@ -414,3 +415,23 @@ def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) a keras.Model """ return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs) + + +def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + <https://arxiv.org/pdf/1905.02244.pdf>`_.
+ >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_page_orientation + >>> model = mobilenet_v3_small_page_orientation(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + Args: + ---- + pretrained: boolean, True if model is pretrained + **kwargs: keyword arguments of the MobileNetV3 architecture + Returns: + ------- + a keras.Model + """ + return _mobilenet_v3("mobilenet_v3_small_page_orientation", pretrained, include_top=True, **kwargs) diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py index 167d2af8d..d06125056 100644 --- a/doctr/models/classification/predictor/pytorch.py +++ b/doctr/models/classification/predictor/pytorch.py @@ -16,8 +16,8 @@ class OrientationPredictor(nn.Module): - """Implements an object able to detect the reading direction of a text box. - 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. + """Implements an object able to detect the reading direction of a text box or a page. + 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. 
Args: ---- @@ -37,13 +37,13 @@ def __init__( @torch.inference_mode() def forward( self, - crops: List[Union[np.ndarray, torch.Tensor]], + inputs: List[Union[np.ndarray, torch.Tensor]], ) -> List[Union[List[int], List[float]]]: # Dimension check - if any(crop.ndim != 3 for crop in crops): - raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + if any(input.ndim != 3 for input in inputs): + raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") - processed_batches = self.pre_processor(crops) + processed_batches = self.pre_processor(inputs) _params = next(self.model.parameters()) self.model, processed_batches = set_device_and_dtype( self.model, processed_batches, _params.device, _params.dtype diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py index 1eb6894ac..95295584f 100644 --- a/doctr/models/classification/predictor/tensorflow.py +++ b/doctr/models/classification/predictor/tensorflow.py @@ -16,8 +16,8 @@ class OrientationPredictor(NestedObject): - """Implements an object able to detect the reading direction of a text box. - 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. + """Implements an object able to detect the reading direction of a text box or a page. + 4 possible orientations: 0, 90, 180, 270 (-90) degrees counter clockwise. 
Args: ---- @@ -37,13 +37,13 @@ def __init__( def __call__( self, - crops: List[Union[np.ndarray, tf.Tensor]], + inputs: List[Union[np.ndarray, tf.Tensor]], ) -> List[Union[List[int], List[float]]]: # Dimension check - if any(crop.ndim != 3 for crop in crops): - raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + if any(input.ndim != 3 for input in inputs): + raise ValueError("incorrect input shape: all inputs are expected to be multi-channel 2D images.") - processed_batches = self.pre_processor(crops) + processed_batches = self.pre_processor(inputs) predicted_batches = [self.model(batch, training=False) for batch in processed_batches] # confidence diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index ae736d773..9368bb225 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -11,7 +11,7 @@ from ..preprocessor import PreProcessor from .predictor import OrientationPredictor -__all__ = ["crop_orientation_predictor"] +__all__ = ["crop_orientation_predictor", "page_orientation_predictor"] ARCHS: List[str] = [ "magc_resnet31", @@ -31,7 +31,7 @@ "vit_s", "vit_b", ] -ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation"] +ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"] def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor: @@ -42,7 +42,7 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient _model = classification.__dict__[arch](pretrained=pretrained) kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) - kwargs["batch_size"] = kwargs.get("batch_size", 128) + kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] predictor = 
OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model @@ -53,17 +53,41 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient def crop_orientation_predictor( arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any ) -> OrientationPredictor: - """Orientation classification architecture. + """Crop orientation classification architecture. >>> import numpy as np >>> from doctr.models import crop_orientation_predictor - >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True) - >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True) + >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8) >>> out = model([input_crop]) Args: ---- - arch: name of the architecture to use (e.g. 'mobilenet_v3_small') + arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') + pretrained: If True, returns a model pre-trained on our recognition crops dataset + **kwargs: keyword arguments to be passed to the OrientationPredictor + + Returns: + ------- + OrientationPredictor + """ + return _orientation_predictor(arch, pretrained, **kwargs) + + +def page_orientation_predictor( + arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any +) -> OrientationPredictor: + """Page orientation classification architecture. + + >>> import numpy as np + >>> from doctr.models import page_orientation_predictor + >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True) + >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + ---- + arch: name of the architecture to use (e.g. 
'mobilenet_v3_small_page_orientation') pretrained: If True, returns a model pre-trained on our recognition crops dataset **kwargs: keyword arguments to be passed to the OrientationPredictor diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index 0e29638ff..5f1270883 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -60,7 +60,8 @@ def test_classification_architectures(arch_name, input_shape, output_size): @pytest.mark.parametrize( "arch_name, input_shape", [ - ["mobilenet_v3_small_crop_orientation", (3, 128, 128)], + ["mobilenet_v3_small_crop_orientation", (3, 256, 256)], + ["mobilenet_v3_small_page_orientation", (3, 512, 512)], ], ) def test_classification_models(arch_name, input_shape): @@ -81,19 +82,30 @@ def test_classification_models(arch_name, input_shape): "arch_name", [ "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", ], ) def test_classification_zoo(arch_name): - batch_size = 16 - # Model - predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) - predictor.model.eval() - - with pytest.raises(ValueError): - predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + if "crop" in arch_name: + batch_size = 16 + input_tensor = torch.rand((batch_size, 3, 256, 256)) + # Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + predictor.model.eval() + + with pytest.raises(ValueError): + predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + else: + batch_size = 2 + input_tensor = torch.rand((batch_size, 3, 512, 512)) + # Model + predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False) + predictor.model.eval() + + with pytest.raises(ValueError): + predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", 
pretrained=False) # object check assert isinstance(predictor, OrientationPredictor) - input_tensor = torch.rand((batch_size, 3, 128, 128)) if torch.cuda.is_available(): predictor.model.cuda() input_tensor = input_tensor.cuda() @@ -123,6 +135,19 @@ def test_crop_orientation_model(mock_text_box): assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) +def test_page_orientation_model(mock_payslip): + text_box_0 = cv2.imread(mock_payslip) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) + + @pytest.mark.parametrize( "arch_name, input_shape, output_size", [ @@ -135,7 +160,8 @@ def test_crop_orientation_model(mock_text_box): ["magc_resnet31", (3, 32, 32), (126,)], ["mobilenet_v3_small", (3, 32, 32), (126,)], ["mobilenet_v3_large", (3, 32, 32), (126,)], - ["mobilenet_v3_small_crop_orientation", (3, 128, 128), (4,)], + ["mobilenet_v3_small_crop_orientation", (3, 256, 256), (4,)], + ["mobilenet_v3_small_page_orientation", (3, 512, 512), (4,)], ["vit_s", (3, 32, 32), (126,)], ["vit_b", (3, 32, 32), (126,)], ["textnet_tiny", (3, 32, 32), (126,)], diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py index dd9223073..23fb019c0 100644 --- a/tests/tensorflow/test_models_classification_tf.py +++ b/tests/tensorflow/test_models_classification_tf.py @@ -50,7 +50,8 @@ def 
test_classification_architectures(arch_name, input_shape, output_size): @pytest.mark.parametrize( "arch_name, input_shape", [ - ["mobilenet_v3_small_crop_orientation", (128, 128, 3)], + ["mobilenet_v3_small_crop_orientation", (256, 256, 3)], + ["mobilenet_v3_small_page_orientation", (512, 512, 3)], ], ) def test_classification_models(arch_name, input_shape): @@ -68,17 +69,28 @@ def test_classification_models(arch_name, input_shape): "arch_name", [ "mobilenet_v3_small_crop_orientation", + "mobilenet_v3_small_page_orientation", ], ) def test_classification_zoo(arch_name): - batch_size = 16 - # Model - predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) - with pytest.raises(ValueError): - predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + if "crop" in arch_name: + batch_size = 16 + input_tensor = tf.random.uniform(shape=[batch_size, 256, 256, 3], minval=0, maxval=1) + # Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + + with pytest.raises(ValueError): + predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False) + else: + batch_size = 2 + input_tensor = tf.random.uniform(shape=[batch_size, 512, 512, 3], minval=0, maxval=1) + # Model + predictor = classification.zoo.page_orientation_predictor(arch_name, pretrained=False) + + with pytest.raises(ValueError): + predictor = classification.zoo.page_orientation_predictor(arch="wrong_model", pretrained=False) # object check assert isinstance(predictor, OrientationPredictor) - input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1) out = predictor(input_tensor) class_idxs, classes, confs = out[0], out[1], out[2] assert isinstance(class_idxs, list) and len(class_idxs) == batch_size @@ -102,6 +114,22 @@ def test_crop_orientation_model(mock_text_box): assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, 
text_box_180, text_box_90])[2]) +# TODO: uncomment when model is available +""" +def test_page_orientation_model(mock_payslip): + text_box_0 = cv2.imread(mock_payslip) + # rotates counter-clockwise + text_box_270 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_90 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_page_orientation", pretrained=True) + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + # 270 degrees is equivalent to -90 degrees + assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2]) +""" + + # temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue) # ref.: https://github.com/mindee/doctr/pull/1201 @pytest.mark.parametrize( @@ -110,7 +138,8 @@ def test_crop_orientation_model(mock_text_box): ["vgg16_bn_r", (32, 32, 3), (126,)], ["mobilenet_v3_small", (512, 512, 3), (126,)], ["mobilenet_v3_large", (512, 512, 3), (126,)], - ["mobilenet_v3_small_crop_orientation", (128, 128, 3), (4,)], + ["mobilenet_v3_small_crop_orientation", (256, 256, 3), (4,)], + ["mobilenet_v3_small_page_orientation", (512, 512, 3), (4,)], ["resnet18", (32, 32, 3), (126,)], ["vit_s", (32, 32, 3), (126,)], ["textnet_tiny", (32, 32, 3), (126,)], @@ -163,7 +192,7 @@ def test_models_onnx_export(arch_name, input_shape, output_size): # Model batch_size = 2 tf.keras.backend.clear_session() - if arch_name == "mobilenet_v3_small_crop_orientation": + if "orientation" in arch_name: model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) else: model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape)