Skip to content

Commit

Permalink
[predictor] aspect ratio true by default (#1279)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Aug 11, 2023
1 parent aa8e6a1 commit 5023af9
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 8 deletions.
2 changes: 1 addition & 1 deletion doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
reco_predictor: RecognitionPredictor,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
detect_language: bool = False,
Expand Down
8 changes: 4 additions & 4 deletions doctr/models/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _predictor(
pretrained: bool,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
det_bs: int = 2,
reco_bs: int = 128,
Expand Down Expand Up @@ -64,7 +64,7 @@ def ocr_predictor(
pretrained: bool = False,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
export_as_straight_boxes: bool = False,
detect_orientation: bool = False,
Expand Down Expand Up @@ -124,7 +124,7 @@ def _kie_predictor(
pretrained: bool,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
det_bs: int = 2,
reco_bs: int = 128,
Expand Down Expand Up @@ -169,7 +169,7 @@ def kie_predictor(
pretrained: bool = False,
pretrained_backbone: bool = True,
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
export_as_straight_boxes: bool = False,
detect_orientation: bool = False,
Expand Down
108 changes: 108 additions & 0 deletions tests/pytorch/test_models_zoo_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from torch import nn

from doctr import models
from doctr.file_utils import CLASS_NAME
from doctr.io import Document, DocumentFile
from doctr.io.elements import KIEDocument
from doctr.models import detection, recognition
from doctr.models.detection.predictor import DetectionPredictor
from doctr.models.detection.zoo import detection_predictor
from doctr.models.kie_predictor import KIEPredictor
from doctr.models.predictor import OCRPredictor
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor
from doctr.models.recognition.zoo import recognition_predictor


@pytest.mark.parametrize(
Expand Down Expand Up @@ -70,6 +73,57 @@ def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
assert out.pages[0].orientation["value"] == orientation


def test_trained_ocr_predictor(mock_tilted_payslip):
    """Run pretrained OCR end to end on the ``mock_tilted_payslip`` fixture.

    Two passes are made over the same document:

    1. ``preserve_aspect_ratio=False`` -- recognized words and their polygon
       geometries are compared against stored reference values.
    2. ``preserve_aspect_ratio=True`` with symmetric padding -- only the first
       recognized word is checked.
    """
    pages = DocumentFile.from_images(mock_tilted_payslip)

    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)

    # Options shared by both OCRPredictor instances built below.
    shared_kwargs = {"assume_straight_pages": True, "straighten_pages": True}

    # --- Pass 1: aspect ratio explicitly NOT preserved ---
    ocr = OCRPredictor(det_predictor, reco_predictor, preserve_aspect_ratio=False, **shared_kwargs)
    result = ocr(pages)

    first_word = result.pages[0].blocks[0].lines[0].words[0]
    assert first_word.value == "Mr."
    expected_mr = np.array(
        [
            [0.08563021, 0.35584526],
            [0.11464554, 0.34078913],
            [0.1274898, 0.36012764],
            [0.09847447, 0.37518377],
        ]
    )
    assert np.allclose(np.array(first_word.geometry), expected_mr)

    last_word = result.pages[0].blocks[1].lines[0].words[-1]
    assert last_word.value == "revised"
    expected_revised = np.array(
        [
            [0.50422498, 0.19551784],
            [0.55741975, 0.16791493],
            [0.56705294, 0.18241881],
            [0.51385817, 0.21002172],
        ]
    )
    assert np.allclose(np.array(last_word.geometry), expected_revised)

    # --- Pass 2: aspect ratio preserved, symmetric padding ---
    # The detection predictor is rebuilt with matching settings; the
    # recognition predictor from pass 1 is reused.
    det_predictor = detection_predictor(
        "db_resnet50",
        pretrained=True,
        batch_size=2,
        assume_straight_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )
    ocr = OCRPredictor(
        det_predictor,
        reco_predictor,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
        **shared_kwargs,
    )
    result = ocr(pages)

    assert result.pages[0].blocks[0].lines[0].words[0].value == "Mr."


@pytest.mark.parametrize(
"assume_straight_pages, straighten_pages",
[
Expand Down Expand Up @@ -127,6 +181,60 @@ def test_kiepredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pa
assert out.pages[0].orientation["value"] == orientation


def test_trained_kie_predictor(mock_tilted_payslip):
    """Run pretrained KIE end to end on the ``mock_tilted_payslip`` fixture.

    Two passes are made over the same document:

    1. ``preserve_aspect_ratio=False`` -- the output type, recognized values
       and polygon geometries are compared against stored reference values.
    2. ``preserve_aspect_ratio=True`` with symmetric padding -- the output
       type and the first recognized value are checked.
    """
    doc = DocumentFile.from_images(mock_tilted_payslip)

    det_predictor = detection_predictor("db_resnet50", pretrained=True, batch_size=2, assume_straight_pages=True)
    reco_predictor = recognition_predictor("crnn_vgg16_bn", pretrained=True, batch_size=128)

    # --- Pass 1: aspect ratio explicitly NOT preserved ---
    predictor = KIEPredictor(
        det_predictor,
        reco_predictor,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=False,
    )

    out = predictor(doc)

    assert isinstance(out, KIEDocument)
    predictions = out.pages[0].predictions[CLASS_NAME]

    assert predictions[0].value == "Mr."
    geometry_mr = np.array(
        [[0.08563021, 0.35584526], [0.11464554, 0.34078913], [0.1274898, 0.36012764], [0.09847447, 0.37518377]]
    )
    assert np.allclose(np.array(predictions[0].geometry), geometry_mr)

    # Index 7 is where "revised" sits in the prediction list for this fixture.
    assert predictions[7].value == "revised"
    geometry_revised = np.array(
        [[0.50422498, 0.19551784], [0.55741975, 0.16791493], [0.56705294, 0.18241881], [0.51385817, 0.21002172]]
    )
    assert np.allclose(np.array(predictions[7].geometry), geometry_revised)

    # --- Pass 2: aspect ratio preserved, symmetric padding ---
    # The detection predictor is rebuilt with matching settings; the
    # recognition predictor from pass 1 is reused.
    det_predictor = detection_predictor(
        "db_resnet50",
        pretrained=True,
        batch_size=2,
        assume_straight_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )

    predictor = KIEPredictor(
        det_predictor,
        reco_predictor,
        assume_straight_pages=True,
        straighten_pages=True,
        preserve_aspect_ratio=True,
        symmetric_pad=True,
    )

    out = predictor(doc)

    assert isinstance(out, KIEDocument)
    assert out.pages[0].predictions[CLASS_NAME][0].value == "Mr."


def _test_predictor(predictor):
# Output checks
assert isinstance(predictor, OCRPredictor)
Expand Down
2 changes: 2 additions & 0 deletions tests/tensorflow/test_models_zoo_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def test_trained_ocr_predictor(mock_tilted_payslip):
reco_predictor,
assume_straight_pages=True,
straighten_pages=True,
preserve_aspect_ratio=False,
)

out = predictor(doc)
Expand Down Expand Up @@ -189,6 +190,7 @@ def test_trained_kie_predictor(mock_tilted_payslip):
reco_predictor,
assume_straight_pages=True,
straighten_pages=True,
preserve_aspect_ratio=False,
)

out = predictor(doc)
Expand Down

0 comments on commit 5023af9

Please sign in to comment.