Skip to content

Commit

Permalink
unify tf and pt test
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Aug 11, 2023
1 parent 9f74a1f commit 95c7286
Showing 1 changed file with 108 additions and 0 deletions.
108 changes: 108 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from torch import nn

from doctr import models
from doctr.file_utils import CLASS_NAME
from doctr.io import Document, DocumentFile
from doctr.io.elements import KIEDocument
from doctr.models import detection, recognition
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.detection.zoo import detection_predictor
from doctr.models.kie_predictor import KIEPredictor
from doctr.models.predictor import OCRPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.recognition.zoo import recognition_predictor


@pytest.mark.parametrize(
Expand Down Expand Up @@ -70,6 +73,57 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
assert out.pages[0].orientation["value"] == orientation


def test_trained_ocr_predictor(mock_tilted_payslip):
    """Run a pretrained OCR pipeline on the tilted payslip fixture and pin known outputs.

    Two configurations are exercised with the same recognition model:

    1. Detection without aspect-ratio preservation: the first word of the first
       block and the last word of the second block are checked for both their
       text and their 4-point geometry.
    2. Detection with ``preserve_aspect_ratio=True`` and ``symmetric_pad=True``:
       only the text of the first word is checked.

    Args:
        mock_tilted_payslip: fixture passed straight to ``DocumentFile.from_images``
            (presumably an image path or raw bytes -- defined outside this block).
    """
    pages = DocumentFile.from_images(mock_tilted_payslip)

    detector = detection_predictor(
        "db_resnet50",
        pretrained=True,
        batch_size=2,
        assume_straight_pages=True,
    )
    recognizer = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)

    # --- Configuration 1: aspect ratio not preserved ---
    ocr = OCRPredictor(
        detector,
        recognizer,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=False,
    )
    result = ocr(pages)

    leading_word = result.pages[0].blocks[0].lines[0].words[0]
    assert leading_word.value == "Mr."
    expected_mr = np.array(
        [
            [0.08563021, 0.35584526],
            [0.11464554, 0.34078913],
            [0.1274898, 0.36012764],
            [0.09847447, 0.37518377],
        ]
    )
    assert np.allclose(np.array(leading_word.geometry), expected_mr)

    trailing_word = result.pages[0].blocks[1].lines[0].words[-1]
    assert trailing_word.value == "revised"
    expected_revised = np.array(
        [
            [0.50422498, 0.19551784],
            [0.55741975, 0.16791493],
            [0.56705294, 0.18241881],
            [0.51385817, 0.21002172],
        ]
    )
    assert np.allclose(np.array(trailing_word.geometry), expected_revised)

    # --- Configuration 2: aspect ratio preserved with symmetric padding ---
    # The recognition predictor is reused; only the detection side is rebuilt.
    padded_detector = detection_predictor(
        "db_resnet50",
        pretrained=True,
        batch_size=2,
        assume_straight_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )
    padded_ocr = OCRPredictor(
        padded_detector,
        recognizer,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )
    padded_result = padded_ocr(pages)

    assert padded_result.pages[0].blocks[0].lines[0].words[0].value == "Mr."


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -127,6 +181,60 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
assert out.pages[0].orientation["value"] == orientation


def test_trained_kie_predictor(mock_tilted_payslip):
    """Run a pretrained KIE pipeline on the tilted payslip fixture and pin known outputs.

    Two configurations are exercised with the same recognition model:

    1. Detection without aspect-ratio preservation: predictions 0 and 7 of the
       ``CLASS_NAME`` class are checked for both their text and their 4-point
       geometry.
    2. Detection with ``preserve_aspect_ratio=True`` and ``symmetric_pad=True``:
       only the text of prediction 0 is checked.

    In both cases the predictor output must be a ``KIEDocument``.

    Args:
        mock_tilted_payslip: fixture passed straight to ``DocumentFile.from_images``
            (presumably an image path or raw bytes -- defined outside this block).
    """
    doc = DocumentFile.from_images(mock_tilted_payslip)

    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)

    # --- Configuration 1: aspect ratio not preserved ---
    predictor = KIEPredictor(
        det_predictor,
        reco_predictor,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=False,
    )

    out = predictor(doc)

    assert isinstance(out, KIEDocument)
    predictions = out.pages[0].predictions[CLASS_NAME]

    assert predictions[0].value == "Mr."
    geometry_mr = np.array(
        [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]]
    )
    assert np.allclose(np.array(predictions[0].geometry), geometry_mr)

    # Index 7 is the position of "revised" in the flat per-class prediction list.
    assert predictions[7].value == "revised"
    geometry_revised = np.array(
        [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]]
    )
    assert np.allclose(np.array(predictions[7].geometry), geometry_revised)

    # --- Configuration 2: aspect ratio preserved with symmetric padding ---
    # The recognition predictor is reused; only the detection side is rebuilt.
    det_predictor = detection_predictor(
        "db_resnet50",
        pretrained=True,
        batch_size=2,
        assume_straight_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )

    predictor = KIEPredictor(
        det_predictor,
        reco_predictor,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )

    out = predictor(doc)

    assert isinstance(out, KIEDocument)
    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."


def _test_predictor(predictor):
# Output checks
assert isinstance(predictor, OCRPredictor)
Expand Down

0 comments on commit 95c7286

Please sign in to comment.