diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py index 530b5ff90d..3a4ca5740b 100644 --- a/doctr/models/kie_predictor/pytorch.py +++ b/doctr/models/kie_predictor/pytorch.py @@ -44,7 +44,7 @@ def __init__( reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, straighten_pages: bool = False, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py index 3ec311b9d5..d2ba221d57 100644 --- a/doctr/models/kie_predictor/tensorflow.py +++ b/doctr/models/kie_predictor/tensorflow.py @@ -46,7 +46,7 @@ def __init__( reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, straighten_pages: bool = False, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py index 15ff2db995..aa8a878a93 100644 --- a/doctr/models/predictor/pytorch.py +++ b/doctr/models/predictor/pytorch.py @@ -44,7 +44,7 @@ def __init__( reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, straighten_pages: bool = False, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py index 91ad27b83f..f3ef69a685 100644 --- a/doctr/models/predictor/tensorflow.py +++ b/doctr/models/predictor/tensorflow.py @@ -46,7 +46,7 @@ def __init__( reco_predictor: RecognitionPredictor, assume_straight_pages: bool = True, straighten_pages: bool = False, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index a178989d4c..e70d49e68e 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -19,7 +19,7 @@ def _predictor( pretrained: bool, pretrained_backbone: bool = True, assume_straight_pages: bool = True, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, det_bs: int = 2, reco_bs: int = 128, @@ -64,7 +64,7 @@ def ocr_predictor( pretrained: bool = False, pretrained_backbone: bool = True, assume_straight_pages: bool = True, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, export_as_straight_boxes: bool = False, detect_orientation: bool = False, @@ -124,7 +124,7 @@ def _kie_predictor( pretrained: bool, pretrained_backbone: bool = True, assume_straight_pages: bool = True, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, det_bs: int = 2, reco_bs: int = 128, @@ -169,7 +169,7 @@ def kie_predictor( pretrained: bool = False, pretrained_backbone: bool = True, assume_straight_pages: bool = True, - preserve_aspect_ratio: bool = False, + preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, export_as_straight_boxes: bool = False, detect_orientation: bool = False, diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py index e2b26b5b55..cefb77176f 100644 --- a/tests/pytorch/test_models_zoo_pt.py +++ b/tests/pytorch/test_models_zoo_pt.py @@ -3,14 +3,17 @@ from torch import nn from doctr import models +from doctr.file_utils import CLASS_NAME from doctr.io import Document, DocumentFile from doctr.io.elements import KIEDocument from doctr.models import detection, recognition from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.detection.zoo import detection_predictor from doctr.models.kie_predictor import KIEPredictor from doctr.models.predictor import OCRPredictor from doctr.models.preprocessor import PreProcessor from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.zoo import recognition_predictor @pytest.mark.parametrize( @@ -70,6 +73,57 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa assert out.pages[0].orientation["value"] == orientation +def test_trained_ocr_predictor(mock_tilted_payslip): + doc = DocumentFile.from_images(mock_tilted_payslip) + + det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + geometry_mr = np.array( + [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]] + ) + assert np.allclose(np.array(out.pages[0].blocks[0].lines[0].words[0].geometry), geometry_mr) + + assert out.pages[0].blocks[1].lines[0].words[-1].value == "revised" + geometry_revised = np.array( + [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]] + ) + assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr." + + @pytest.mark.parametrize( "assume_straight_pages, straighten_pages", [ @@ -127,6 +181,60 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa assert out.pages[0].orientation["value"] == orientation +def test_trained_kie_predictor(mock_tilted_payslip): + doc = DocumentFile.from_images(mock_tilted_payslip) + + det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True) + reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=False, + ) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." + geometry_mr = np.array( + [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]] + ) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr) + + print(out.pages[0].predictions[CLASS_NAME]) + assert out.pages[0].predictions[CLASS_NAME][7].value == "revised" + geometry_revised = np.array( + [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]] + ) + assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][7].geometry), geometry_revised) + + det_predictor = detection_predictor( + "db_resnet50", + pretrained=True, + batch_size=2, + assume_straight_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + predictor = KIEPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=True, + straighten_pages=True, + preserve_aspect_ratio=True, + symmetric_pad=True, + ) + + out = predictor(doc) + + assert isinstance(out, KIEDocument) + assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr." + + def _test_predictor(predictor): # Output checks assert isinstance(predictor, OCRPredictor) diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py index 5cd4d0be5d..6d4b85e2c8 100644 --- a/tests/tensorflow/test_models_zoo_tf.py +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -83,6 +83,7 @@ def test_trained_ocr_predictor(mock_tilted_payslip): reco_predictor, assume_straight_pages=True, straighten_pages=True, + preserve_aspect_ratio=False, ) out = predictor(doc) @@ -189,6 +190,7 @@ def test_trained_kie_predictor(mock_tilted_payslip): reco_predictor, assume_straight_pages=True, straighten_pages=True, + preserve_aspect_ratio=False, ) out = predictor(doc)