From 028571104a81bf4f5bbf48c655ecc2f9ef8a3c2a Mon Sep 17 00:00:00 2001
From: Felix Dittrich
Date: Fri, 19 Apr 2024 18:24:09 +0200
Subject: [PATCH] fix orient train and modify for page (#1559)

---
 docs/source/modules/models.rst                 |  2 +-
 .../classification/mobilenet/pytorch.py        | 21 ++++++++++------
 .../classification/mobilenet/tensorflow.py     | 21 ++++++++++------
 .../classification/predictor/pytorch.py        | 10 +++-----
 .../classification/predictor/tensorflow.py     | 10 +++-----
 doctr/models/classification/zoo.py             | 18 ++++++-------
 doctr/models/kie_predictor/base.py             |  4 +--
 doctr/models/predictor/base.py                 |  4 +--
 .../train_pytorch_orientation.py               |  4 +--
 .../train_tensorflow_orientation.py            |  4 +--
 .../pytorch/test_models_classification_pt.py   | 23 +++++++++--------
 .../test_models_classification_tf.py           | 25 ++++++++++---------
 12 files changed, 77 insertions(+), 69 deletions(-)

diff --git a/docs/source/modules/models.rst b/docs/source/modules/models.rst
index 380464cced..0424e9c3dd 100644
--- a/docs/source/modules/models.rst
+++ b/docs/source/modules/models.rst
@@ -25,7 +25,7 @@ doctr.models.classification

 .. autofunction:: doctr.models.classification.mobilenet_v3_large_r

-.. autofunction:: doctr.models.classification.mobilenet_v3_small_orientation
+.. autofunction:: doctr.models.classification.mobilenet_v3_small_crop_orientation

 .. autofunction:: doctr.models.classification.magc_resnet31

diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py
index dc3d18f59e..73b98d833f 100644
--- a/doctr/models/classification/mobilenet/pytorch.py
+++ b/doctr/models/classification/mobilenet/pytorch.py
@@ -19,7 +19,7 @@
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
 ]

 default_cfgs: Dict[str, Dict[str, Any]] = {
@@ -51,13 +51,20 @@
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (3, 128, 128),
-        "classes": [0, 90, 180, 270],
+        "classes": [0, -90, 180, 90],
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt&src=0",
     },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (3, 512, 512),
+        "classes": [0, -90, 180, 90],
+        "url": None,
+    },
 }

@@ -212,14 +219,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3
     )


-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.

     >>> import torch
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
     >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
     >>> out = model(input_tensor)

@@ -233,7 +240,7 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> m
         a torch.nn.Module
     """
     return _mobilenet_v3(
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_crop_orientation",
         pretrained,
         ignore_keys=["classifier.3.weight", "classifier.3.bias"],
         **kwargs,
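
A minimal sketch (not part of the patch) of how the renamed PyTorch constructor behaves after this change, using the defaults from the config above; the class list is read back from model.cfg exactly as the predictor does:

# Illustrative only: renamed constructor and its updated config (pretrained=False avoids any download).
import torch

from doctr.models import mobilenet_v3_small_crop_orientation

model = mobilenet_v3_small_crop_orientation(pretrained=False)
print(model.cfg["classes"])      # [0, -90, 180, 90] -- 270 is now stored as -90
print(model.cfg["input_shape"])  # (3, 128, 128), i.e. text-crop resolution

logits = model(torch.rand((1, 3, 128, 128), dtype=torch.float32))
print(logits.shape)              # torch.Size([1, 4]), one score per orientation class
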
diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py
index a12521865c..92adc887bf 100644
--- a/doctr/models/classification/mobilenet/tensorflow.py
+++ b/doctr/models/classification/mobilenet/tensorflow.py
@@ -21,7 +21,7 @@
     "mobilenet_v3_small_r",
     "mobilenet_v3_large",
     "mobilenet_v3_large_r",
-    "mobilenet_v3_small_orientation",
+    "mobilenet_v3_small_crop_orientation",
 ]


@@ -54,13 +54,20 @@
         "classes": list(VOCABS["french"]),
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-3d61452e.zip&src=0",
     },
-    "mobilenet_v3_small_orientation": {
+    "mobilenet_v3_small_crop_orientation": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (128, 128, 3),
-        "classes": [0, 90, 180, 270],
+        "classes": [0, -90, 180, 90],
         "url": "https://doctr-static.mindee.com/models?id=v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip&src=0",
     },
+    "mobilenet_v3_small_page_orientation": {
+        "mean": (0.694, 0.695, 0.693),
+        "std": (0.299, 0.296, 0.301),
+        "input_shape": (512, 512, 3),
+        "classes": [0, -90, 180, 90],
+        "url": None,
+    },
 }

@@ -386,14 +393,14 @@ def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3
     return _mobilenet_v3("mobilenet_v3_large_r", pretrained, True, **kwargs)


-def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
+def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3:
     """MobileNetV3-Small architecture as described in
     `"Searching for MobileNetV3",
     <https://arxiv.org/pdf/1905.02244.pdf>`_.

     >>> import tensorflow as tf
-    >>> from doctr.models import mobilenet_v3_small_orientation
-    >>> model = mobilenet_v3_small_orientation(pretrained=False)
+    >>> from doctr.models import mobilenet_v3_small_crop_orientation
+    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
     >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32)
     >>> out = model(input_tensor)

@@ -406,4 +413,4 @@ def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> M
     -------
         a keras.Model
     """
-    return _mobilenet_v3("mobilenet_v3_small_orientation", pretrained, include_top=True, **kwargs)
+    return _mobilenet_v3("mobilenet_v3_small_crop_orientation", pretrained, include_top=True, **kwargs)
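
The new mobilenet_v3_small_page_orientation entry is configuration-only for now: its "url" is None, so no pretrained page-orientation checkpoint ships with this patch. A hedged sketch of reading that config back, assuming the TensorFlow backend is installed and importing default_cfgs from the backend module where it is defined:

# Illustrative only: inspect the page-orientation configuration added above.
from doctr.models.classification.mobilenet.tensorflow import default_cfgs

page_cfg = default_cfgs["mobilenet_v3_small_page_orientation"]
print(page_cfg["input_shape"])  # (512, 512, 3) -- full pages instead of 128x128 crops
print(page_cfg["classes"])      # [0, -90, 180, 90], same signed convention as the crop model
print(page_cfg["url"])          # None -- no pretrained weights published yet
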
diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py
index f8bed39eb8..167d2af8d9 100644
--- a/doctr/models/classification/predictor/pytorch.py
+++ b/doctr/models/classification/predictor/pytorch.py
@@ -12,10 +12,10 @@
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.utils import set_device_and_dtype

-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]


-class CropOrientationPredictor(nn.Module):
+class OrientationPredictor(nn.Module):
     """Implements an object able to detect the reading direction of a text box.
     4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.

@@ -57,11 +57,7 @@ def forward(
         predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches]

         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        # Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
-        classes = [
-            int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
-            for idx in class_idxs
-        ]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
         confs = [round(float(p), 2) for prob in probs for p in prob]

         return [class_idxs, classes, confs]

diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py
index 2299bacbd7..1eb6894ac6 100644
--- a/doctr/models/classification/predictor/tensorflow.py
+++ b/doctr/models/classification/predictor/tensorflow.py
@@ -12,10 +12,10 @@
 from doctr.models.preprocessor import PreProcessor
 from doctr.utils.repr import NestedObject

-__all__ = ["CropOrientationPredictor"]
+__all__ = ["OrientationPredictor"]


-class CropOrientationPredictor(NestedObject):
+class OrientationPredictor(NestedObject):
     """Implements an object able to detect the reading direction of a text box.
     4 possible orientations: 0, 90, 180, 270 degrees counter clockwise.

@@ -52,11 +52,7 @@ def __call__(
         predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches]

         class_idxs = [int(pred) for batch in predicted_batches for pred in batch]
-        # Keep unified with page orientation range (counter clock rotation => negative) so 270 -> -90
-        classes = [
-            int(self.model.cfg["classes"][idx]) if int(self.model.cfg["classes"][idx]) != 270 else -90
-            for idx in class_idxs
-        ]
+        classes = [int(self.model.cfg["classes"][idx]) for idx in class_idxs]
         confs = [round(float(p), 2) for prob in probs for p in prob]

         return [class_idxs, classes, confs]
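
With the 270-to--90 special case removed, the renamed OrientationPredictor simply reports whatever angles the model config declares. A sketch of wiring one by hand, mirroring what the zoo helper below does; the PreProcessor arguments and the example output values are assumptions, not part of the patch:

# Illustrative only: the predictor returns [class_idxs, classes, confs].
import numpy as np

from doctr.models import classification
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.preprocessor import PreProcessor

model = classification.mobilenet_v3_small_crop_orientation(pretrained=False)
predictor = OrientationPredictor(PreProcessor((128, 128), batch_size=1), model)

class_idxs, classes, confs = predictor([np.zeros((64, 256, 3), dtype=np.uint8)])
# class_idxs: indices into model.cfg["classes"], e.g. [2]
# classes:    the angles themselves, read straight from the config, e.g. [180]
# confs:      confidence scores rounded to two decimals, e.g. [0.29]
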
diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py
index 6179ff976a..ae736d773e 100644
--- a/doctr/models/classification/zoo.py
+++ b/doctr/models/classification/zoo.py
@@ -9,7 +9,7 @@
 from .. import classification
 from ..preprocessor import PreProcessor
-from .predictor import CropOrientationPredictor
+from .predictor import OrientationPredictor

 __all__ = ["crop_orientation_predictor"]

@@ -31,10 +31,10 @@
     "vit_s",
     "vit_b",
 ]
-ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_orientation"]
+ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation"]


-def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> CropOrientationPredictor:
+def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
     if arch not in ORIENTATION_ARCHS:
         raise ValueError(f"unknown architecture '{arch}'")

@@ -44,15 +44,15 @@ def _crop_orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> C
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128)
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
-    predictor = CropOrientationPredictor(
+    predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
     return predictor


 def crop_orientation_predictor(
-    arch: str = "mobilenet_v3_small_orientation", pretrained: bool = False, **kwargs: Any
-) -> CropOrientationPredictor:
+    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
+) -> OrientationPredictor:
     """Orientation classification architecture.

     >>> import numpy as np
@@ -65,10 +65,10 @@ def crop_orientation_predictor(
     ----
         arch: name of the architecture to use (e.g. 'mobilenet_v3_small')
         pretrained: If True, returns a model pre-trained on our recognition crops dataset
-        **kwargs: keyword arguments to be passed to the CropOrientationPredictor
+        **kwargs: keyword arguments to be passed to the OrientationPredictor

     Returns:
     -------
-        CropOrientationPredictor
+        OrientationPredictor
     """
-    return _crop_orientation_predictor(arch, pretrained, **kwargs)
+    return _orientation_predictor(arch, pretrained, **kwargs)

diff --git a/doctr/models/kie_predictor/base.py b/doctr/models/kie_predictor/base.py
index 107009bed4..63a87f5900 100644
--- a/doctr/models/kie_predictor/base.py
+++ b/doctr/models/kie_predictor/base.py
@@ -7,7 +7,7 @@
 from doctr.models.builder import KIEDocumentBuilder

-from ..classification.predictor import CropOrientationPredictor
+from ..classification.predictor import OrientationPredictor
 from ..predictor.base import _OCRPredictor

 __all__ = ["_KIEPredictor"]

@@ -28,7 +28,7 @@ class _KIEPredictor(_OCRPredictor):
         kwargs: keyword args of `DocumentBuilder`
     """

-    crop_orientation_predictor: Optional[CropOrientationPredictor]
+    crop_orientation_predictor: Optional[OrientationPredictor]

     def __init__(
         self,

diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py
index bc5bfcb5db..0033b2cbf4 100644
--- a/doctr/models/predictor/base.py
+++ b/doctr/models/predictor/base.py
@@ -12,7 +12,7 @@
 from .._utils import rectify_crops, rectify_loc_preds
 from ..classification import crop_orientation_predictor
-from ..classification.predictor import CropOrientationPredictor
+from ..classification.predictor import OrientationPredictor

 __all__ = ["_OCRPredictor"]

@@ -32,7 +32,7 @@ class _OCRPredictor:
         **kwargs: keyword args of `DocumentBuilder`
     """

-    crop_orientation_predictor: Optional[CropOrientationPredictor]
+    crop_orientation_predictor: Optional[OrientationPredictor]

     def __init__(
         self,
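
For downstream code only the names change: the zoo keeps its crop_orientation_predictor entry point, while the old architecture string is no longer registered in ORIENTATION_ARCHS. A small sketch of both behaviours (the defaults mentioned in the comment come from the zoo code above):

from doctr.models.classification import crop_orientation_predictor

# defaults (128x128 crops, batch_size=128, mean/std from the model config) can be
# overridden through kwargs, which are forwarded to the PreProcessor
predictor = crop_orientation_predictor(
    arch="mobilenet_v3_small_crop_orientation", pretrained=False, batch_size=32
)

try:
    crop_orientation_predictor(arch="mobilenet_v3_small_orientation", pretrained=False)
except ValueError as err:
    print(err)  # unknown architecture 'mobilenet_v3_small_orientation'
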
diff --git a/references/classification/train_pytorch_orientation.py b/references/classification/train_pytorch_orientation.py
index 688e485644..e1d0df02f3 100644
--- a/references/classification/train_pytorch_orientation.py
+++ b/references/classification/train_pytorch_orientation.py
@@ -35,7 +35,7 @@
 from doctr.models.utils import export_model_to_onnx
 from utils import EarlyStopper, plot_recorder, plot_samples

-CLASSES = [0, 90, 180, 270]
+CLASSES = [0, -90, 180, 90]


 def rnd_rotate(img: torch.Tensor, target):
@@ -191,7 +191,7 @@ def main(args):

     torch.backends.cudnn.benchmark = True

-    input_size = (256, 256) if args.type == "page" else (32, 32)
+    input_size = (512, 512) if args.type == "page" else (256, 256)

     # Load val data generator
     st = time.time()

diff --git a/references/classification/train_tensorflow_orientation.py b/references/classification/train_tensorflow_orientation.py
index ed0479172c..8af14a5cf4 100644
--- a/references/classification/train_tensorflow_orientation.py
+++ b/references/classification/train_tensorflow_orientation.py
@@ -30,7 +30,7 @@
 from doctr.transforms.functional import rotated_img_tensor
 from utils import EarlyStopper, plot_recorder, plot_samples

-CLASSES = [0, 90, 180, 270]
+CLASSES = [0, -90, 180, 90]


 def rnd_rotate(img: tf.Tensor, target):
@@ -147,7 +147,7 @@ def main(args):
     if not isinstance(args.workers, int):
         args.workers = min(16, mp.cpu_count())

-    input_size = (256, 256) if args.type == "page" else (32, 32)
+    input_size = (512, 512) if args.type == "page" else (256, 256)

     # AMP
     if args.amp:
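
The reference training scripts keep their structure; only the label set and the training resolutions move. A hedged sketch of the resulting conventions — how rnd_rotate actually builds its targets is not shown in this patch, so the index mapping below is only an assumption for illustration:

CLASSES = [0, -90, 180, 90]

def training_input_size(task: str) -> tuple:
    # mirrors the updated scripts: page-level models now train at 512x512,
    # crop-level models at 256x256 (previously 256x256 and 32x32)
    return (512, 512) if task == "page" else (256, 256)

def orientation_target(angle: int) -> int:
    # one plausible way to turn an angle into a class index; the scripts'
    # rnd_rotate helper (not shown here) handles the actual image rotation
    return CLASSES.index(angle)

assert training_input_size("page") == (512, 512)
assert orientation_target(-90) == 1
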
diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py
index 11fe0b7df1..0e29638ff3 100644
--- a/tests/pytorch/test_models_classification_pt.py
+++ b/tests/pytorch/test_models_classification_pt.py
@@ -8,7 +8,7 @@
 import torch

 from doctr.models import classification
-from doctr.models.classification.predictor import CropOrientationPredictor
+from doctr.models.classification.predictor import OrientationPredictor
 from doctr.models.utils import export_model_to_onnx

@@ -60,7 +60,7 @@ def test_classification_architectures(arch_name, input_shape, output_size):
 @pytest.mark.parametrize(
     "arch_name, input_shape",
     [
-        ["mobilenet_v3_small_orientation", (3, 128, 128)],
+        ["mobilenet_v3_small_crop_orientation", (3, 128, 128)],
     ],
 )
 def test_classification_models(arch_name, input_shape):
@@ -80,7 +80,7 @@ def test_classification_models(arch_name, input_shape):
 @pytest.mark.parametrize(
     "arch_name",
     [
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_crop_orientation",
     ],
 )
 def test_classification_zoo(arch_name):
@@ -92,7 +92,7 @@ def test_classification_zoo(arch_name):
     with pytest.raises(ValueError):
         predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
     # object check
-    assert isinstance(predictor, CropOrientationPredictor)
+    assert isinstance(predictor, OrientationPredictor)
     input_tensor = torch.rand((batch_size, 3, 128, 128))
     if torch.cuda.is_available():
         predictor.model.cuda()
@@ -112,14 +112,15 @@ def test_classification_zoo(arch_name):

 def test_crop_orientation_model(mock_text_box):
     text_box_0 = cv2.imread(mock_text_box)
-    text_box_90 = np.rot90(text_box_0, 1)
+    # rotates counter-clockwise
+    text_box_270 = np.rot90(text_box_0, 1)
     text_box_180 = np.rot90(text_box_0, 2)
-    text_box_270 = np.rot90(text_box_0, 3)
-    classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
-    assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[0] == [0, 1, 2, 3]
+    text_box_90 = np.rot90(text_box_0, 3)
+    classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True)
+    assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
     # 270 degrees is equivalent to -90 degrees
-    assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[1] == [0, 90, 180, -90]
-    assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_90, text_box_180, text_box_270])[2])
+    assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
+    assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])

@@ -134,7 +135,7 @@ def test_crop_orientation_model(mock_text_box):
         ["magc_resnet31", (3, 32, 32), (126,)],
         ["mobilenet_v3_small", (3, 32, 32), (126,)],
         ["mobilenet_v3_large", (3, 32, 32), (126,)],
-        ["mobilenet_v3_small_orientation", (3, 128, 128), (4,)],
+        ["mobilenet_v3_small_crop_orientation", (3, 128, 128), (4,)],
         ["vit_s", (3, 32, 32), (126,)],
         ["vit_b", (3, 32, 32), (126,)],
         ["textnet_tiny", (3, 32, 32), (126,)],
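
The renamed test variables follow numpy's convention: np.rot90 turns counter-clockwise, so a single turn now carries the -90 label (the angle formerly reported as 270). A small self-contained check of the shapes and the index mapping asserted above:

import numpy as np

box = np.zeros((32, 128, 3), dtype=np.uint8)  # a wide text crop

rotated = {
    -90: np.rot90(box, 1),  # one CCW turn  -> class index 1 in [0, -90, 180, 90]
    180: np.rot90(box, 2),  # two CCW turns -> class index 2
    90: np.rot90(box, 3),   # three CCW turns (= 90 clockwise) -> class index 3
}
print({angle: img.shape for angle, img in rotated.items()})
# {-90: (128, 32, 3), 180: (32, 128, 3), 90: (128, 32, 3)}
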
diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py
index 221515ca71..dd92230738 100644
--- a/tests/tensorflow/test_models_classification_tf.py
+++ b/tests/tensorflow/test_models_classification_tf.py
@@ -9,7 +9,7 @@
 import tensorflow as tf

 from doctr.models import classification
-from doctr.models.classification.predictor import CropOrientationPredictor
+from doctr.models.classification.predictor import OrientationPredictor
 from doctr.models.utils import export_model_to_onnx

 system_available_memory = int(psutil.virtual_memory().available / 1024**3)
@@ -50,7 +50,7 @@ def test_classification_architectures(arch_name, input_shape, output_size):
 @pytest.mark.parametrize(
     "arch_name, input_shape",
     [
-        ["mobilenet_v3_small_orientation", (128, 128, 3)],
+        ["mobilenet_v3_small_crop_orientation", (128, 128, 3)],
     ],
 )
 def test_classification_models(arch_name, input_shape):
@@ -67,7 +67,7 @@ def test_classification_models(arch_name, input_shape):
 @pytest.mark.parametrize(
     "arch_name",
     [
-        "mobilenet_v3_small_orientation",
+        "mobilenet_v3_small_crop_orientation",
     ],
 )
 def test_classification_zoo(arch_name):
@@ -77,7 +77,7 @@ def test_classification_zoo(arch_name):
     with pytest.raises(ValueError):
         predictor = classification.zoo.crop_orientation_predictor(arch="wrong_model", pretrained=False)
     # object check
-    assert isinstance(predictor, CropOrientationPredictor)
+    assert isinstance(predictor, OrientationPredictor)
     input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1)
     out = predictor(input_tensor)
     class_idxs, classes, confs = out[0], out[1], out[2]
@@ -91,14 +91,15 @@ def test_classification_zoo(arch_name):
 def test_crop_orientation_model(mock_text_box):
     text_box_0 = cv2.imread(mock_text_box)
-    text_box_90 = np.rot90(text_box_0, 1)
+    # rotates counter-clockwise
+    text_box_270 = np.rot90(text_box_0, 1)
     text_box_180 = np.rot90(text_box_0, 2)
-    text_box_270 = np.rot90(text_box_0, 3)
-    classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True)
-    assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[0] == [0, 1, 2, 3]
+    text_box_90 = np.rot90(text_box_0, 3)
+    classifier = classification.crop_orientation_predictor("mobilenet_v3_small_crop_orientation", pretrained=True)
+    assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3]
     # 270 degrees is equivalent to -90 degrees
-    assert classifier([text_box_0, text_box_90, text_box_180, text_box_270])[1] == [0, 90, 180, -90]
-    assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_90, text_box_180, text_box_270])[2])
+    assert classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90]
+    assert all(isinstance(pred, float) for pred in classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2])


 # temporarily fix to avoid killing the CI (tf2onnx v1.14 memory leak issue)
 @pytest.mark.parametrize(
     "arch_name, input_shape, output_size",
     [
         ["vgg16_bn_r", (32, 32, 3), (126,)],
         ["mobilenet_v3_small", (512, 512, 3), (126,)],
         ["mobilenet_v3_large", (512, 512, 3), (126,)],
-        ["mobilenet_v3_small_orientation", (128, 128, 3), (4,)],
+        ["mobilenet_v3_small_crop_orientation", (128, 128, 3), (4,)],
         ["resnet18", (32, 32, 3), (126,)],
         ["vit_s", (32, 32, 3), (126,)],
         ["textnet_tiny", (32, 32, 3), (126,)],
@@ -162,7 +163,7 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
     # Model
     batch_size = 2
     tf.keras.backend.clear_session()
-    if arch_name == "mobilenet_v3_small_orientation":
+    if arch_name == "mobilenet_v3_small_crop_orientation":
         model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape)
     else:
         model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape)
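
One detail worth noting in the ONNX export test: the TensorFlow constructor for the renamed architecture already hard-codes include_top=True (see the mobilenet/tensorflow.py hunk above), which is why this branch only overrides input_shape. A hedged sketch of that construction, using pretrained=False to stay offline (the test itself uses pretrained=True):

import tensorflow as tf

from doctr.models import classification

arch_name, input_shape = "mobilenet_v3_small_crop_orientation", (128, 128, 3)
model = classification.__dict__[arch_name](pretrained=False, input_shape=input_shape)
out = model(tf.random.uniform(shape=[2, *input_shape], maxval=1, dtype=tf.float32))
print(out.shape)  # (2, 4) -- the tensor that the test then feeds to export_model_to_onnx
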