Skip to content

Commit

Permalink
update some parts
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Oct 5, 2023
1 parent 28f4981 commit 2e30672
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 72 deletions.
55 changes: 12 additions & 43 deletions doctr/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
from langdetect import LangDetectException, detect_langs

__all__ = ["estimate_orientation", "get_bitmap_angle", "get_language", "invert_data_structure"]
__all__ = ["estimate_orientation", "get_language", "invert_data_structure"]


def get_max_width_length_ratio(contour: np.ndarray) -> float:
Expand All @@ -26,25 +26,32 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
return max(w / h, h / w)


def estimate_orientation(seq_map: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
"""Estimate the angle of the general document orientation based on the
lines of the document and the assumption that they should be horizontal.
Args:
seq_mab: the binarized image of the document
img: the img or bitmap to analyze (H, W, C)
n_ct: the number of contours used for the orientation estimation
ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
Returns:
the angle of the general document orientation
"""

if np.max(img) <= 1 and np.min(img) >= 0 or (np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 1):
thresh = img.astype(np.uint8)
if np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 3:
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray_img = cv2.medianBlur(gray_img, 5)
thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

# try to merge words in lines
(h, w) = seq_map.shape[:2]
(h, w) = img.shape[:2]
k_x = max(1, (floor(w / 100)))
k_y = max(1, (floor(h / 100)))
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
thresh = cv2.dilate(seq_map, kernel, iterations=1)
thresh = cv2.dilate(thresh, kernel, iterations=1)

# extract contours
contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
Expand All @@ -66,44 +73,6 @@ def estimate_orientation(seq_map: np.ndarray, n_ct: int = 50, ratio_threshold_fo
return -median_low(angles)


def get_bitmap_angle(bitmap: np.ndarray, n_ct: int = 20, std_max: float = 3.0) -> float:
"""From a binarized segmentation map, find contours and fit min area rectangles to determine page angle
Args:
bitmap: binarized segmentation map
n_ct: number of contours to use to fit page angle
std_max: maximum deviation of the angle distribution to consider the mean angle reliable
Returns:
The angle of the page
"""
# Find all contours on binarized seg map
contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
# Sort contours
contours = sorted(contours, key=cv2.contourArea, reverse=True)

# Find largest contours and fit angles
# Track heights and widths to find aspect ratio (determine is rotation is clockwise)
angles, heights, widths = [], [], []
for ct in contours[:n_ct]:
_, (w, h), alpha = cv2.minAreaRect(ct)
widths.append(w)
heights.append(h)
angles.append(alpha)

if np.std(angles) > std_max:
# Edge case with angles of both 0 and 90°, or multi_oriented docs
angle = 0.0
else:
angle = -np.mean(angles)
# Determine rotation direction (clockwise/counterclockwise)
# Angle coverage: [-90°, +90°], half of the quadrant
if np.sum(widths) < np.sum(heights): # CounterClockwise
angle = 90 + angle

return angle


def rectify_crops(
crops: List[np.ndarray],
orientations: List[int],
Expand Down
4 changes: 2 additions & 2 deletions doctr/models/detection/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def forward(
self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
]
preds = [pred for batch in predicted_batches for pred in batch["preds"]]
seq_maps = [
seg_maps = [
pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
]
if return_maps:
return preds, seq_maps
return preds, seg_maps
return preds
4 changes: 2 additions & 2 deletions doctr/models/detection/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __call__(
]

preds = [pred for batch in predicted_batches for pred in batch["preds"]]
seq_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
if return_maps:
return preds, seq_maps
return preds, seg_maps
return preds
6 changes: 3 additions & 3 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def forward(
loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seq_maps = [
seg_maps = [
np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
np.uint8
)
for out_map in out_maps
]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
]
Expand All @@ -92,7 +92,7 @@ def forward(
origin_page_orientations = (
origin_page_orientations
if self.detect_orientation
else [estimate_orientation(seq_map) for seq_map in seq_maps]
else [estimate_orientation(seq_map) for seq_map in seg_maps]
)
pages = [
rotate_image(page, -angle, expand=False) # type: ignore[arg-type]
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/kie_predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def __call__(
loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seq_maps = [
seg_maps = [
np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
np.uint8
)
for out_map in out_maps
]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
]
Expand All @@ -92,7 +92,7 @@ def __call__(
origin_page_orientations = (
origin_page_orientations
if self.detect_orientation
else [estimate_orientation(seq_map) for seq_map in seq_maps]
else [estimate_orientation(seq_map) for seq_map in seg_maps]
)
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
# Forward again to get predictions on straight pages
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def forward(
loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seq_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
]
Expand All @@ -87,7 +87,7 @@ def forward(
origin_page_orientations = (
origin_page_orientations
if self.detect_orientation
else [estimate_orientation(seq_map) for seq_map in seq_maps]
else [estimate_orientation(seq_map) for seq_map in seg_maps]
)
pages = [
rotate_image(page, -angle, expand=False) # type: ignore[arg-type]
Expand Down
6 changes: 3 additions & 3 deletions doctr/models/predictor/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def __call__(
loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seq_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
if self.detect_orientation:
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
]
Expand All @@ -87,7 +87,7 @@ def __call__(
origin_page_orientations = (
origin_page_orientations
if self.detect_orientation
else [estimate_orientation(seq_map) for seq_map in seq_maps]
else [estimate_orientation(seq_map) for seq_map in seg_maps]
)
pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
# forward again to get predictions on straight pages
Expand Down
4 changes: 2 additions & 2 deletions scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def main(args):

out = model(doc)

for page, img in zip(out.pages, doc):
page.show(img, block=not args.noblock, interactive=not args.static)
for page in out.pages:
page.show(block=not args.noblock, interactive=not args.static)


def parse_args():
Expand Down
28 changes: 19 additions & 9 deletions tests/common/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import requests

from doctr.io import reader
from doctr.models._utils import estimate_orientation, get_bitmap_angle, get_language, invert_data_structure
from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
from doctr.utils import geometry


Expand All @@ -23,22 +23,32 @@ def mock_image(tmpdir_factory):

@pytest.fixture(scope="function")
def mock_bitmap(mock_image):
bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY))
bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY) / 255.0)
bitmap = np.expand_dims(bitmap, axis=-1)
return bitmap


def test_get_bitmap_angle(mock_bitmap):
angle = get_bitmap_angle(mock_bitmap)
assert abs(angle - 30.0) < 1.0
def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip):
assert estimate_orientation(mock_image * 0) == 0

# test binarized image
angle = estimate_orientation(mock_bitmap)
assert abs(angle - 30.0) < 1.0

def test_estimate_orientation(mock_bitmap):
assert estimate_orientation(mock_bitmap * 0) == 0
angle = estimate_orientation(mock_bitmap * 255)
assert abs(angle - 30.0) < 1.0

angle = estimate_orientation(mock_bitmap)
angle = estimate_orientation(mock_image)
assert abs(angle - 30.0) < 1.0

rotated = geometry.rotate_image(mock_bitmap, -angle)
rotated = geometry.rotate_image(mock_image, -angle)
angle_rotated = estimate_orientation(rotated)
assert abs(angle_rotated) < 1.0

mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip)
assert (estimate_orientation(mock_tilted_payslip) - 30.0) < 1.0

rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True)
angle_rotated = estimate_orientation(rotated)
assert abs(angle_rotated) < 1.0

Expand Down
6 changes: 5 additions & 1 deletion tests/pytorch/test_models_detection_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ def test_detection_zoo(arch_name):
input_tensor = input_tensor.cuda()

with torch.no_grad():
out = predictor(input_tensor)
out, seq_maps = predictor(input_tensor, return_maps=True)
assert all(isinstance(boxes, dict) for boxes in out)
assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)
assert all(isinstance(seq_map, np.ndarray) for seq_map in seq_maps)
assert all(seq_map.shape[:2] == (1024, 1024) for seq_map in seq_maps)
# check that all values in the seq_maps are between 0 and 1
assert all((seq_map >= 0).all() and (seq_map <= 1).all() for seq_map in seq_maps)


def test_erode():
Expand Down
6 changes: 5 additions & 1 deletion tests/tensorflow/test_models_detection_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,13 @@ def test_detection_zoo(arch_name):
# object check
assert isinstance(predictor, DetectionPredictor)
input_tensor = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1)
out = predictor(input_tensor)
out, seq_maps = predictor(input_tensor, return_maps=True)
assert all(isinstance(boxes, dict) for boxes in out)
assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)
assert all(isinstance(seq_map, np.ndarray) for seq_map in seq_maps)
assert all(seq_map.shape[:2] == (1024, 1024) for seq_map in seq_maps)
# check that all values in the seq_maps are between 0 and 1
assert all((seq_map >= 0).all() and (seq_map <= 1).all() for seq_map in seq_maps)


def test_detection_zoo_error():
Expand Down

0 comments on commit 2e30672

Please sign in to comment.