diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b283e9a6a..f88af5972e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: - id: no-commit-to-branch args: ['--branch', 'main'] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.7.4 + rev: v0.8.1 hooks: - id: ruff args: [ --fix ] diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst index db632701ba..073172efbc 100644 --- a/docs/source/using_doctr/using_model_export.rst +++ b/docs/source/using_doctr/using_model_export.rst @@ -31,7 +31,11 @@ Advantages: .. code:: python3 import torch - predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True).cuda().half() + predictor = ocr_predictor( + reco_arch="crnn_mobilenet_v3_small", + det_arch="linknet_resnet34", + pretrained=True + ).cuda().half() res = predictor(doc) .. tab:: TensorFlow @@ -41,8 +45,63 @@ Advantages: import tensorflow as tf from tensorflow.keras import mixed_precision mixed_precision.set_global_policy('mixed_float16') - predictor = ocr_predictor(reco_arch="crnn_mobilenet_v3_small", det_arch="linknet_resnet34", pretrained=True) - + predictor = ocr_predictor( + reco_arch="crnn_mobilenet_v3_small", + det_arch="linknet_resnet34", + pretrained=True + ) + + +Compiling your models (PyTorch only) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**NOTE:** + +- This feature is only available if you use PyTorch as backend. +- The recognition architecture `master` is not supported for model compilation yet. +- We provide only official support for the default (`inductor`) backend, but you can try other backends, configurations depending on your hardware and requirements as well. + +Compiling your PyTorch models with `torch.compile` optimizes the model by converting it to a graph representation and applying backends that can improve performance. +This process can make inference faster and reduce memory overhead during execution. + +Further information can be found in the `PyTorch documentation `_. + +.. code:: + + import torch + from doctr.models import ( + ocr_predictor, + vitstr_small, + fast_base, + mobilenet_v3_small_crop_orientation, + mobilenet_v3_small_page_orientation, + crop_orientation_predictor, + page_orientation_predictor + ) + + # Compile the models + detection_model = torch.compile( + fast_base(pretrained=True).eval() + ) + recognition_model = torch.compile( + vitstr_small(pretrained=True).eval() + ) + crop_orientation_model = torch.compile( + mobilenet_v3_small_crop_orientation(pretrained=True).eval() + ) + page_orientation_model = torch.compile( + mobilenet_v3_small_page_orientation(pretrained=True).eval() + ) + + predictor = models.ocr_predictor( + detection_model, recognition_model, assume_straight_pages=False + ) + # NOTE: Only required for non-straight pages (`assume_straight_pages=False`) and non-disabled orientation classification + # Set the orientation predictors + predictor.crop_orientation_predictor = crop_orientation_predictor(crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(page_orientation_model) + + compiled_out = predictor(doc) Export to ONNX ^^^^^^^^^^^^^^ @@ -64,7 +123,11 @@ It defines a common format for representing models, including the network struct input_shape = (3, 32, 128) model = vitstr_small(pretrained=True, exportable=True) dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32) - model_path = export_model_to_onnx(model, model_name="vitstr.onnx, dummy_input=dummy_input) + model_path = export_model_to_onnx( + model, + model_name="vitstr.onnx", + dummy_input=dummy_input + ) .. tab:: TensorFlow @@ -78,7 +141,11 @@ It defines a common format for representing models, including the network struct input_shape = (32, 128, 3) model = vitstr_small(pretrained=True, exportable=True) dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")] - model_path, output = export_model_to_onnx(model, model_name="vitstr.onnx", dummy_input=dummy_input) + model_path, output = export_model_to_onnx( + model, + model_name="vitstr.onnx", + dummy_input=dummy_input + ) Using your ONNX exported model diff --git a/docs/source/using_doctr/using_models.rst b/docs/source/using_doctr/using_models.rst index 40a820cee5..e42dea0881 100644 --- a/docs/source/using_doctr/using_models.rst +++ b/docs/source/using_doctr/using_models.rst @@ -298,7 +298,7 @@ For instance, this snippet instantiates an end-to-end ocr_predictor working with .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor model = ocr_predictor('linknet_resnet18', pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True) @@ -309,7 +309,7 @@ Additionally, you can change the batch size of the underlying detection and reco .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor model = ocr_predictor(pretrained=True, det_bs=4, reco_bs=1024) To modify the output structure you can pass the following arguments to the predictor which will be handled by the underlying `DocumentBuilder`: @@ -322,7 +322,7 @@ For example to disable the automatic grouping of lines into blocks: .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor model = ocr_predictor(pretrained=True, resolve_blocks=False) @@ -477,7 +477,7 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_page_orientation=True) @@ -489,7 +489,7 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor model = ocr_predictor(pretrained=True, assume_straight_pages=False, disable_crop_orientation=True) @@ -497,7 +497,7 @@ This will only have an effect with `assume_straight_pages=False` and/or `straigh .. code:: python3 - from doctr.model import ocr_predictor + from doctr.models import ocr_predictor class CustomHook: def __call__(self, loc_preds): diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 7cce3083ea..c147e7f104 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -5,7 +5,7 @@ from typing import Any -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available from .. import classification from ..preprocessor import PreProcessor @@ -48,7 +48,14 @@ def _orientation_predictor( # Load directly classifier from backbone _model = classification.__dict__[arch](pretrained=pretrained) else: - if not isinstance(arch, classification.MobileNetV3): + allowed_archs = [classification.MobileNetV3] + if is_torch_available(): + # Adding the type for torch compiled models to the allowed architectures + from doctr.models.utils import _CompiledModule + + allowed_archs.append(_CompiledModule) + + if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") _model = arch diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py index ddd8cfa595..cad6a74aaf 100644 --- a/doctr/models/detection/differentiable_binarization/pytorch.py +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -205,11 +205,16 @@ def forward( out["out_map"] = prob_map if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]: + return [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + # Post-process boxes (keep only text predictions) - out["preds"] = [ - dict(zip(self.class_names, preds)) - for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) - ] + out["preds"] = _postprocess(prob_map) if target is not None: thresh_map = self.thresh_head(feat_concat) diff --git a/doctr/models/detection/fast/pytorch.py b/doctr/models/detection/fast/pytorch.py index 91218dba11..595004ba98 100644 --- a/doctr/models/detection/fast/pytorch.py +++ b/doctr/models/detection/fast/pytorch.py @@ -196,11 +196,16 @@ def forward( out["out_map"] = prob_map if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]: + return [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + # Post-process boxes (keep only text predictions) - out["preds"] = [ - dict(zip(self.class_names, preds)) - for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) - ] + out["preds"] = _postprocess(prob_map) if target is not None: loss = self.compute_loss(logits, target) diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py index 88cb24204b..ff6860718e 100644 --- a/doctr/models/detection/linknet/pytorch.py +++ b/doctr/models/detection/linknet/pytorch.py @@ -183,11 +183,16 @@ def forward( out["out_map"] = prob_map if target is None or return_preds: - # Post-process boxes - out["preds"] = [ - dict(zip(self.class_names, preds)) - for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) - ] + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]: + return [ + dict(zip(self.class_names, preds)) + for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + # Post-process boxes (keep only text predictions) + out["preds"] = _postprocess(prob_map) if target is not None: loss = self.compute_loss(logits, target) diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index da33b846ff..6d87b6b1a9 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -56,7 +56,14 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, if isinstance(_model, detection.FAST): _model = reparameterize(_model) else: - if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)): + allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST] + if is_torch_available(): + # Adding the type for torch compiled models to the allowed architectures + from doctr.models.utils import _CompiledModule + + allowed_archs.append(_CompiledModule) + + if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") _model = arch diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py index 1c0641bba2..bc77c98a24 100644 --- a/doctr/models/recognition/crnn/pytorch.py +++ b/doctr/models/recognition/crnn/pytorch.py @@ -213,8 +213,13 @@ def forward( out["out_map"] = logits if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]: + return self.postprocessor(logits) + # Post-process boxes - out["preds"] = self.postprocessor(logits) + out["preds"] = _postprocess(logits) if target is not None: out["loss"] = self.compute_loss(logits, target) diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py index 3de463d09b..01b1df0735 100644 --- a/doctr/models/recognition/master/pytorch.py +++ b/doctr/models/recognition/master/pytorch.py @@ -209,7 +209,13 @@ def forward( out["out_map"] = logits if return_preds: - out["preds"] = self.postprocessor(logits) + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]: + return self.postprocessor(logits) + + # Post-process boxes + out["preds"] = _postprocess(logits) return out diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 414739436a..a2ee524ced 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -372,8 +372,13 @@ def forward( out["out_map"] = logits if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(logits: torch.Tensor) -> list[tuple[str, float]]: + return self.postprocessor(logits) + # Post-process boxes - out["preds"] = self.postprocessor(logits) + out["preds"] = _postprocess(logits) if target is not None: out["loss"] = loss diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py index caf1900575..fbbae121d4 100644 --- a/doctr/models/recognition/sar/pytorch.py +++ b/doctr/models/recognition/sar/pytorch.py @@ -262,8 +262,13 @@ def forward( out["out_map"] = decoded_features if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]: + return self.postprocessor(decoded_features) + # Post-process boxes - out["preds"] = self.postprocessor(decoded_features) + out["preds"] = _postprocess(decoded_features) if target is not None: out["loss"] = self.compute_loss(decoded_features, gt, seq_len) diff --git a/doctr/models/recognition/utils.py b/doctr/models/recognition/utils.py index f6b299ade8..daa3c6d2d0 100644 --- a/doctr/models/recognition/utils.py +++ b/doctr/models/recognition/utils.py @@ -22,7 +22,7 @@ def merge_strings(a: str, b: str, dil_factor: float) -> str: A merged character sequence. Example:: - >>> from doctr.model.recognition.utils import merge_sequences + >>> from doctr.models.recognition.utils import merge_sequences >>> merge_sequences('abcd', 'cdefgh', 1.4) 'abcdefgh' >>> merge_sequences('abcdi', 'cdefgh', 1.4) @@ -70,7 +70,7 @@ def merge_multi_strings(seq_list: list[str], dil_factor: float) -> str: A merged character sequence Example:: - >>> from doctr.model.recognition.utils import merge_multi_sequences + >>> from doctr.models.recognition.utils import merge_multi_sequences >>> merge_multi_sequences(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4) 'abcdefghijkl' """ diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index db763d53c7..7bd57793cc 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -107,8 +107,13 @@ def forward( out["out_map"] = decoded_features if target is None or return_preds: + # Disable for torch.compile compatibility + @torch.compiler.disable # type: ignore[attr-defined] + def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]: + return self.postprocessor(decoded_features) + # Post-process boxes - out["preds"] = self.postprocessor(decoded_features) + out["preds"] = _postprocess(decoded_features) if target is not None: out["loss"] = self.compute_loss(decoded_features, gt, seq_len) diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index f60431441c..56327f8a65 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -5,7 +5,7 @@ from typing import Any -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_torch_available from doctr.models.preprocessor import PreProcessor from .. import recognition @@ -35,9 +35,14 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True) ) else: - if not isinstance( - arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq) - ): + allowed_archs = [recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq] + if is_torch_available(): + # Adding the type for torch compiled models to the allowed architectures + from doctr.models.utils import _CompiledModule + + allowed_archs.append(_CompiledModule) + + if not isinstance(arch, tuple(allowed_archs)): raise ValueError(f"unknown architecture: {type(arch)}") _model = arch diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py index c9ea3a43ca..0582a4de2b 100644 --- a/doctr/models/utils/pytorch.py +++ b/doctr/models/utils/pytorch.py @@ -18,8 +18,12 @@ "export_model_to_onnx", "_copy_tensor", "_bf16_to_float32", + "_CompiledModule", ] +# torch compiled model type +_CompiledModule = torch._dynamo.eval_frame.OptimizedModule + def _copy_tensor(x: torch.Tensor) -> torch.Tensor: return x.clone().detach() diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index b3d25af173..5441181a1c 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -9,7 +9,7 @@ from doctr.models import classification from doctr.models.classification.predictor import OrientationPredictor -from doctr.models.utils import export_model_to_onnx +from doctr.models.utils import _CompiledModule, export_model_to_onnx def _test_classification(model, input_shape, output_size, batch_size=2): @@ -152,6 +152,19 @@ def test_crop_orientation_model(mock_text_box): with pytest.raises(ValueError): _ = classification.crop_orientation_predictor(classification.textnet_tiny(pretrained=True)) + # Test torch compilation + compiled_model = torch.compile(classification.mobilenet_v3_small_crop_orientation(pretrained=True)) + compiled_classifier = classification.crop_orientation_predictor(compiled_model) + + assert isinstance(compiled_model, _CompiledModule) + assert isinstance(compiled_classifier, OrientationPredictor) + assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all( + isinstance(pred, float) + for pred in compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2] + ) + def test_page_orientation_model(mock_payslip): text_box_0 = cv2.imread(mock_payslip) @@ -184,6 +197,19 @@ def test_page_orientation_model(mock_payslip): with pytest.raises(ValueError): _ = classification.page_orientation_predictor(classification.textnet_tiny(pretrained=True)) + # Test torch compilation + compiled_model = torch.compile(classification.mobilenet_v3_small_page_orientation(pretrained=True)) + compiled_classifier = classification.page_orientation_predictor(compiled_model) + + assert isinstance(compiled_model, _CompiledModule) + assert isinstance(compiled_classifier, OrientationPredictor) + assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[0] == [0, 1, 2, 3] + assert compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[1] == [0, -90, 180, 90] + assert all( + isinstance(pred, float) + for pred in compiled_classifier([text_box_0, text_box_270, text_box_180, text_box_90])[2] + ) + @pytest.mark.parametrize( "arch_name, input_shape, output_size", diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py index 247cdb2880..a8598eb2fa 100644 --- a/tests/pytorch/test_models_detection_pt.py +++ b/tests/pytorch/test_models_detection_pt.py @@ -8,11 +8,12 @@ import torch from doctr.file_utils import CLASS_NAME +from doctr.io import DocumentFile from doctr.models import detection from doctr.models.detection._utils import dilate, erode from doctr.models.detection.fast.pytorch import reparameterize from doctr.models.detection.predictor import DetectionPredictor -from doctr.models.utils import export_model_to_onnx +from doctr.models.utils import _CompiledModule, export_model_to_onnx @pytest.mark.parametrize("train_mode", [True, False]) @@ -186,3 +187,39 @@ def test_models_onnx_export(arch_name, input_shape, output_size): assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) except AssertionError: pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + + +@pytest.mark.parametrize( + "arch_name", + [ + "db_resnet34", + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + "linknet_resnet34", + "linknet_resnet50", + "fast_tiny", + "fast_small", + "fast_base", + ], +) +def test_torch_compiled_models(arch_name, mock_payslip): + doc = DocumentFile.from_images([mock_payslip]) + predictor = detection.zoo.detection_predictor(arch_name, pretrained=True) + assert isinstance(predictor, DetectionPredictor) + out, seg_maps = predictor(doc, return_maps=True) + + # Compile the model + compiled_model = torch.compile(detection.__dict__[arch_name](pretrained=True).eval()) + assert isinstance(compiled_model, _CompiledModule) + compiled_predictor = detection.zoo.detection_predictor(compiled_model) + compiled_out, seg_maps = compiled_predictor(doc, return_maps=True) + + # Compare + assert all( + np.allclose(out_boxes[CLASS_NAME], compiled_out_boxes[CLASS_NAME], atol=1e-4) + for out_boxes, compiled_out_boxes in zip(out, compiled_out) + ) + assert all( + np.allclose(seg_map, compiled_seg_map, atol=1e-4) for seg_map, compiled_seg_map in zip(seg_maps, seg_maps) + ) diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py index e4df34060b..194089a37b 100644 --- a/tests/pytorch/test_models_recognition_pt.py +++ b/tests/pytorch/test_models_recognition_pt.py @@ -7,6 +7,7 @@ import pytest import torch +from doctr.io import DocumentFile from doctr.models import recognition from doctr.models.recognition.crnn.pytorch import CTCPostProcessor from doctr.models.recognition.master.pytorch import MASTERPostProcessor @@ -14,7 +15,7 @@ from doctr.models.recognition.predictor import RecognitionPredictor from doctr.models.recognition.sar.pytorch import SARPostProcessor from doctr.models.recognition.vitstr.pytorch import ViTSTRPostProcessor -from doctr.models.utils import export_model_to_onnx +from doctr.models.utils import _CompiledModule, export_model_to_onnx system_available_memory = int(psutil.virtual_memory().available / 1024**3) @@ -154,3 +155,33 @@ def test_models_onnx_export(arch_name, input_shape): assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) except AssertionError: pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + # "master", NOTE: MASTER model isn't 100% safe compilable yet (pytorch v2.5.1) - sometimes it fails to compile. + "vitstr_small", + "vitstr_base", + "parseq", + ], +) +def test_torch_compiled_models(arch_name, mock_text_box): + doc = DocumentFile.from_images([mock_text_box]) + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=True) + assert isinstance(predictor, RecognitionPredictor) + out = predictor(doc) + + # Compile the model + compiled_model = torch.compile(recognition.__dict__[arch_name](pretrained=True).eval()) + assert isinstance(compiled_model, _CompiledModule) + compiled_predictor = recognition.zoo.recognition_predictor(compiled_model) + compiled_out = compiled_predictor(doc) + + # Compare + assert out[0][0] == compiled_out[0][0] + assert np.allclose(out[0][1], compiled_out[0][1], atol=1e-4) diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index 3ea22ca9b6..868842ce82 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import torch from torch import nn from doctr import models @@ -412,3 +413,38 @@ def test_zoo_models(det_arch, reco_arch): # passing detection model as recognition model with pytest.raises(ValueError): models.kie_predictor(reco_arch=det_model, pretrained=True) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [ + ["fast_base", "crnn_vgg16_bn"], + ], +) +def test_end_to_end_torch_compile(det_arch, reco_arch, mock_payslip): + doc = DocumentFile.from_images(mock_payslip) + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True, assume_straight_pages=False) + out = predictor(doc) + + assert isinstance(out, Document) + + # Compile the models + detection_model = torch.compile(detection.__dict__[det_arch](pretrained=True).eval()) + recognition_model = torch.compile(recognition.__dict__[reco_arch](pretrained=True).eval()) + crop_orientation_model = torch.compile(mobilenet_v3_small_crop_orientation(pretrained=True).eval()) + page_orientation_model = torch.compile(mobilenet_v3_small_page_orientation(pretrained=True).eval()) + + predictor = models.ocr_predictor(detection_model, recognition_model, assume_straight_pages=False) + # Set the orientation predictors + # NOTE: only required for non-straight pages and non-disabled orientation classification + predictor.crop_orientation_predictor = crop_orientation_predictor(crop_orientation_model) + predictor.page_orientation_predictor = page_orientation_predictor(page_orientation_model) + compiled_out = predictor(doc) + + # Check that the number of word detections is the same + assert len(out.pages[0].blocks[0].lines[0].words) == len(compiled_out.pages[0].blocks[0].lines[0].words) + # Check that the words are the same + assert all( + word.value == compiled_out.pages[0].blocks[0].lines[0].words[i].value + for i, word in enumerate(out.pages[0].blocks[0].lines[0].words) + )