diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py index 42f9142497..530590bc61 100644 --- a/doctr/models/predictor/base.py +++ b/doctr/models/predictor/base.py @@ -8,7 +8,7 @@ import numpy as np from doctr.models.builder import DocumentBuilder -from doctr.utils.geometry import extract_crops, extract_rcrops, rotate_image +from doctr.utils.geometry import extract_crops, extract_rcrops, remove_image_padding, rotate_image from .._utils import estimate_orientation, rectify_crops, rectify_loc_preds from ..classification import crop_orientation_predictor, page_orientation_predictor @@ -107,8 +107,8 @@ def _straighten_pages( ] ) return [ - # expand if height and width are not equal - rotate_image(page, angle, expand=page.shape[0] != page.shape[1]) + # expand if height and width are not equal, then remove the padding + remove_image_padding(rotate_image(page, angle, expand=page.shape[0] != page.shape[1])) for page, angle in zip(pages, origin_pages_orientations) ] diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py index d16ac3df86..21ad0dd7f4 100644 --- a/doctr/utils/geometry.py +++ b/doctr/utils/geometry.py @@ -20,6 +20,7 @@ "rotate_boxes", "compute_expanded_shape", "rotate_image", + "remove_image_padding", "estimate_page_angle", "convert_to_relative_coords", "rotate_abs_geoms", @@ -351,6 +352,26 @@ def rotate_image( return rot_img +def remove_image_padding(image: np.ndarray) -> np.ndarray: + """Remove black border padding from an image + + Args: + ---- + image: numpy tensor to remove padding from + + Returns: + ------- + Image with padding removed + """ + # Find the bounding box of the non-black region + rows = np.any(image, axis=1) + cols = np.any(image, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return image[rmin : rmax + 1, cmin : cmax + 1] + + def estimate_page_angle(polys: np.ndarray) -> float: """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the estimated angle ccw in degrees diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py index afeed8a87c..d2524a6ab7 100644 --- a/tests/common/test_utils_geometry.py +++ b/tests/common/test_utils_geometry.py @@ -142,6 +142,17 @@ def test_rotate_image(): assert rotated[0, :, 0].sum() <= 1 +def test_remove_image_padding(): + img = np.ones((32, 64, 3), dtype=np.float32) + padded = np.pad(img, ((10, 10), (20, 20), (0, 0))) + cropped = geometry.remove_image_padding(padded) + assert np.all(cropped == img) + + # No padding + cropped = geometry.remove_image_padding(img) + assert np.all(cropped == img) + + @pytest.mark.parametrize( "abs_geoms, img_size, rel_geoms", [