diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py
index e2b26b5b55..cefb77176f 100644
--- a/tests/pytorch/test_models_zoo_pt.py
+++ b/tests/pytorch/test_models_zoo_pt.py
@@ -3,14 +3,17 @@
 from torch import nn
 
 from doctr import models
+from doctr.file_utils import CLASS_NAME
 from doctr.io import Document, DocumentFile
 from doctr.io.elements import KIEDocument
 from doctr.models import detection, recognition
 from doctr.models.detection.predictor import DetectionPredictor
+from doctr.models.detection.zoo import detection_predictor
 from doctr.models.kie_predictor import KIEPredictor
 from doctr.models.predictor import OCRPredictor
 from doctr.models.preprocessor import PreProcessor
 from doctr.models.recognition.predictor import RecognitionPredictor
+from doctr.models.recognition.zoo import recognition_predictor
 
 
 @pytest.mark.parametrize(
@@ -70,6 +73,61 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
     assert out.pages[0].orientation["value"] == orientation
 
 
+def test_trained_ocr_predictor(mock_tilted_payslip):
+    """Check words and geometries predicted by pretrained models on the mock_tilted_payslip fixture."""
+    doc = DocumentFile.from_images(mock_tilted_payslip)
+
+    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
+    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)
+
+    predictor = OCRPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+        preserve_aspect_ratio=False,
+    )
+
+    out = predictor(doc)
+
+    # First word of the first line of the first block, with its expected 4-point polygon
+    assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr."
+    geometry_mr = np.array(
+        [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]]
+    )
+    assert np.allclose(np.array(out.pages[0].blocks[0].lines[0].words[0].geometry), geometry_mr)
+
+    # Last word of the first line of the second block
+    assert out.pages[0].blocks[1].lines[0].words[-1].value == "revised"
+    geometry_revised = np.array(
+        [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]]
+    )
+    assert np.allclose(np.array(out.pages[0].blocks[1].lines[0].words[-1].geometry), geometry_revised)
+
+    # Rebuild the pipeline with aspect-ratio preservation and symmetric padding
+    det_predictor = detection_predictor(
+        "db_resnet50",
+        pretrained=True,
+        batch_size=2,
+        assume_straight_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    predictor = OCRPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    out = predictor(doc)
+
+    assert out.pages[0].blocks[0].lines[0].words[0].value == "Mr."
+
+
 @pytest.mark.parametrize(
     "assume_straight_pages, straighten_pages",
     [
@@ -127,6 +185,63 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
     assert out.pages[0].orientation["value"] == orientation
 
 
+def test_trained_kie_predictor(mock_tilted_payslip):
+    """Check KIE predictions made by pretrained models on the mock_tilted_payslip fixture."""
+    doc = DocumentFile.from_images(mock_tilted_payslip)
+
+    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
+    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+        preserve_aspect_ratio=False,
+    )
+
+    out = predictor(doc)
+
+    assert isinstance(out, KIEDocument)
+    # First prediction stored under CLASS_NAME, with its expected 4-point polygon
+    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."
+    geometry_mr = np.array(
+        [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]]
+    )
+    assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][0].geometry), geometry_mr)
+
+    # Prediction at index 7 under CLASS_NAME is expected to be "revised"
+    assert out.pages[0].predictions[CLASS_NAME][7].value == "revised"
+    geometry_revised = np.array(
+        [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]]
+    )
+    assert np.allclose(np.array(out.pages[0].predictions[CLASS_NAME][7].geometry), geometry_revised)
+
+    # Rebuild the pipeline with aspect-ratio preservation and symmetric padding
+    det_predictor = detection_predictor(
+        "db_resnet50",
+        pretrained=True,
+        batch_size=2,
+        assume_straight_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    predictor = KIEPredictor(
+        det_predictor,
+        reco_predictor,
+        assume_straight_pages=True,
+        straighten_pages=True,
+        preserve_aspect_ratio=True,
+        symmetric_pad=True,
+    )
+
+    out = predictor(doc)
+
+    assert isinstance(out, KIEDocument)
+    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."
+
+
 def _test_predictor(predictor):
     # Output checks
     assert isinstance(predictor, OCRPredictor)