diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py
index 5c740197db..8ed94f345b 100644
--- a/doctr/models/_utils.py
+++ b/doctr/models/_utils.py
@@ -11,7 +11,7 @@ import numpy as np
 from langdetect import LangDetectException, detect_langs
 
-__all__ = ["estimate_orientation", "get_bitmap_angle", "get_language", "invert_data_structure"]
+__all__ = ["estimate_orientation", "get_language", "invert_data_structure"]
 
 
 def get_max_width_length_ratio(contour: np.ndarray) -> float:
@@ -26,12 +26,12 @@ def get_max_width_length_ratio(contour: np.ndarray) -> float:
     return max(w / h, h / w)
 
 
-def estimate_orientation(seq_map: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
+def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float:
     """Estimate the angle of the general document orientation based on the
     lines of the document and the assumption that they should be horizontal.
 
     Args:
-        seq_mab: the binarized image of the document
+        img: the image or bitmap to analyze (H, W, C)
         n_ct: the number of contours used for the orientation estimation
         ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines
 
@@ -39,12 +39,19 @@ def estimate_orientation(seq_map: np.ndarray, n_ct: int = 50, ratio_threshold_fo
         the angle of the general document orientation
     """
+    if (np.max(img) <= 1 and np.min(img) >= 0) or (np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 1):
+        thresh = img.astype(np.uint8)
+    if np.max(img) <= 255 and np.min(img) >= 0 and img.shape[-1] == 3:
+        gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        gray_img = cv2.medianBlur(gray_img, 5)
+        thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
+
     # try to merge words in lines
-    (h, w) = seq_map.shape[:2]
+    (h, w) = img.shape[:2]
     k_x = max(1, (floor(w / 100)))
     k_y = max(1, (floor(h / 100)))
     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y))
-    thresh = cv2.dilate(seq_map, kernel, iterations=1)
+    thresh = cv2.dilate(thresh, kernel, iterations=1)
 
     # extract contours
     contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
@@ -66,44 +73,6 @@ def estimate_orientation(seq_map: np.ndarray, n_ct: int = 50, ratio_threshold_fo
     return -median_low(angles)
 
 
-def get_bitmap_angle(bitmap: np.ndarray, n_ct: int = 20, std_max: float = 3.0) -> float:
-    """From a binarized segmentation map, find contours and fit min area rectangles to determine page angle
-
-    Args:
-        bitmap: binarized segmentation map
-        n_ct: number of contours to use to fit page angle
-        std_max: maximum deviation of the angle distribution to consider the mean angle reliable
-
-    Returns:
-        The angle of the page
-    """
-    # Find all contours on binarized seg map
-    contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
-    # Sort contours
-    contours = sorted(contours, key=cv2.contourArea, reverse=True)
-
-    # Find largest contours and fit angles
-    # Track heights and widths to find aspect ratio (determine is rotation is clockwise)
-    angles, heights, widths = [], [], []
-    for ct in contours[:n_ct]:
-        _, (w, h), alpha = cv2.minAreaRect(ct)
-        widths.append(w)
-        heights.append(h)
-        angles.append(alpha)
-
-    if np.std(angles) > std_max:
-        # Edge case with angles of both 0 and 90°, or multi_oriented docs
-        angle = 0.0
-    else:
-        angle = -np.mean(angles)
-        # Determine rotation direction (clockwise/counterclockwise)
-        # Angle coverage: [-90°, +90°], half of the quadrant
-        if np.sum(widths) < np.sum(heights):  # CounterClockwise
-            angle = 90 + angle
-
-    return angle
-
-
 def rectify_crops(
     crops: List[np.ndarray],
     orientations: List[int],
diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py
index 4209e758ee..da2c11c568 100644
--- a/doctr/models/detection/predictor/pytorch.py
+++ b/doctr/models/detection/predictor/pytorch.py
@@ -52,9 +52,9 @@ def forward(
             self.model(batch, return_preds=True, return_model_output=True, **kwargs) for batch in processed_batches
         ]
         preds = [pred for batch in predicted_batches for pred in batch["preds"]]
-        seq_maps = [
+        seg_maps = [
             pred.permute(1, 2, 0).detach().cpu().numpy() for batch in predicted_batches for pred in batch["out_map"]
         ]
         if return_maps:
-            return preds, seq_maps
+            return preds, seg_maps
         return preds
diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py
index 8db9ac368c..6b48d71e43 100644
--- a/doctr/models/detection/predictor/tensorflow.py
+++ b/doctr/models/detection/predictor/tensorflow.py
@@ -50,7 +50,7 @@ def __call__(
         ]
 
         preds = [pred for batch in predicted_batches for pred in batch["preds"]]
-        seq_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
+        seg_maps = [pred.numpy() for batch in predicted_batches for pred in batch["out_map"]]
         if return_maps:
-            return preds, seq_maps
+            return preds, seg_maps
         return preds
diff --git a/doctr/models/kie_predictor/pytorch.py b/doctr/models/kie_predictor/pytorch.py
index 6d9091d6dd..ad7de0b2c4 100644
--- a/doctr/models/kie_predictor/pytorch.py
+++ b/doctr/models/kie_predictor/pytorch.py
@@ -75,14 +75,14 @@ def forward(
             loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
 
             # Detect document rotation and rotate pages
-            seq_maps = [
+            seg_maps = [
                 np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
                     np.uint8
                 )
                 for out_map in out_maps
             ]
             if self.detect_orientation:
-                origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
+                origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
                 orientations = [
                     {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
                 ]
@@ -92,7 +92,7 @@ def forward(
                 origin_page_orientations = (
                     origin_page_orientations
                     if self.detect_orientation
-                    else [estimate_orientation(seq_map) for seq_map in seq_maps]
+                    else [estimate_orientation(seg_map) for seg_map in seg_maps]
                 )
                 pages = [
                     rotate_image(page, -angle, expand=False)  # type: ignore[arg-type]
diff --git a/doctr/models/kie_predictor/tensorflow.py b/doctr/models/kie_predictor/tensorflow.py
index 812bb8640a..f4e87d21dc 100644
--- a/doctr/models/kie_predictor/tensorflow.py
+++ b/doctr/models/kie_predictor/tensorflow.py
@@ -75,14 +75,14 @@ def __call__(
             loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
 
             # Detect document rotation and rotate pages
-            seq_maps = [
+            seg_maps = [
                 np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
                     np.uint8
                 )
                 for out_map in out_maps
             ]
             if self.detect_orientation:
-                origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
+                origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
                 orientations = [
                     {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
                 ]
@@ -92,7 +92,7 @@ def __call__(
                 origin_page_orientations = (
                     origin_page_orientations
                     if self.detect_orientation
-                    else [estimate_orientation(seq_map) for seq_map in seq_maps]
+                    else [estimate_orientation(seg_map) for seg_map in seg_maps]
                 )
                 pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
                 # Forward again to get predictions on straight pages
diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py
index 3519aeed20..3a70ee6982 100644
--- a/doctr/models/predictor/pytorch.py
+++ b/doctr/models/predictor/pytorch.py
@@ -75,9 +75,9 @@ def forward(
             loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
 
             # Detect document rotation and rotate pages
-            seq_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
+            seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
             if self.detect_orientation:
-                origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
+                origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
                 orientations = [
                     {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
                 ]
@@ -87,7 +87,7 @@ def forward(
                 origin_page_orientations = (
                     origin_page_orientations
                     if self.detect_orientation
-                    else [estimate_orientation(seq_map) for seq_map in seq_maps]
+                    else [estimate_orientation(seg_map) for seg_map in seg_maps]
                 )
                 pages = [
                     rotate_image(page, -angle, expand=False)  # type: ignore[arg-type]
diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py
index 93b78fe97a..980b5221f4 100644
--- a/doctr/models/predictor/tensorflow.py
+++ b/doctr/models/predictor/tensorflow.py
@@ -75,9 +75,9 @@ def __call__(
            loc_preds_dict, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
 
             # Detect document rotation and rotate pages
-            seq_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
+            seg_maps = [np.where(out_map > kwargs.get("bin_thresh", 0.3), 255, 0).astype(np.uint8) for out_map in out_maps]
             if self.detect_orientation:
-                origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seq_maps]
+                origin_page_orientations = [estimate_orientation(seg_map) for seg_map in seg_maps]
                 orientations = [
                     {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
                 ]
@@ -87,7 +87,7 @@ def __call__(
                 origin_page_orientations = (
                     origin_page_orientations
                     if self.detect_orientation
-                    else [estimate_orientation(seq_map) for seq_map in seq_maps]
+                    else [estimate_orientation(seg_map) for seg_map in seg_maps]
                 )
                 pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
                 # forward again to get predictions on straight pages
diff --git a/scripts/analyze.py b/scripts/analyze.py
index 067ed62685..2e0f19c034 100644
--- a/scripts/analyze.py
+++ b/scripts/analyze.py
@@ -31,8 +31,8 @@ def main(args):
 
     out = model(doc)
 
-    for page, img in zip(out.pages, doc):
-        page.show(img, block=not args.noblock, interactive=not args.static)
+    for page in out.pages:
+        page.show(block=not args.noblock, interactive=not args.static)
 
 
 def parse_args():
diff --git a/tests/common/test_models.py b/tests/common/test_models.py
index 72baf19c46..fe59f7a3ce 100644
--- a/tests/common/test_models.py
+++ b/tests/common/test_models.py
@@ -6,7 +6,7 @@ import requests
 
 from doctr.io import reader
-from doctr.models._utils import estimate_orientation, get_bitmap_angle, get_language, invert_data_structure
+from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
 from doctr.utils import geometry
 
 
@@ -23,22 +23,32 @@ def mock_image(tmpdir_factory):
 @pytest.fixture(scope="function")
 def mock_bitmap(mock_image):
-    bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY))
+    bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY) / 255.0)
+    bitmap = np.expand_dims(bitmap, axis=-1)
     return bitmap
 
 
-def test_get_bitmap_angle(mock_bitmap):
-    angle = get_bitmap_angle(mock_bitmap)
-    assert abs(angle - 30.0) < 1.0
+def test_estimate_orientation(mock_image, mock_bitmap, mock_tilted_payslip):
+    assert estimate_orientation(mock_image * 0) == 0
 
+    # test binarized image
+    angle = estimate_orientation(mock_bitmap)
+    assert abs(angle - 30.0) < 1.0
 
-def test_estimate_orientation(mock_bitmap):
-    assert estimate_orientation(mock_bitmap * 0) == 0
+    angle = estimate_orientation(mock_bitmap * 255)
+    assert abs(angle - 30.0) < 1.0
 
-    angle = estimate_orientation(mock_bitmap)
+    angle = estimate_orientation(mock_image)
     assert abs(angle - 30.0) < 1.0
 
-    rotated = geometry.rotate_image(mock_bitmap, -angle)
+    rotated = geometry.rotate_image(mock_image, -angle)
+    angle_rotated = estimate_orientation(rotated)
+    assert abs(angle_rotated) < 1.0
+
+    mock_tilted_payslip = reader.read_img_as_numpy(mock_tilted_payslip)
+    assert abs(estimate_orientation(mock_tilted_payslip) - 30.0) < 1.0
+
+    rotated = geometry.rotate_image(mock_tilted_payslip, -30, expand=True)
     angle_rotated = estimate_orientation(rotated)
     assert abs(angle_rotated) < 1.0
diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py
index 39eae65168..8dac82d436 100644
--- a/tests/pytorch/test_models_detection_pt.py
+++ b/tests/pytorch/test_models_detection_pt.py
@@ -95,9 +95,13 @@ def test_detection_zoo(arch_name):
         input_tensor = input_tensor.cuda()
 
     with torch.no_grad():
-        out = predictor(input_tensor)
+        out, seg_maps = predictor(input_tensor, return_maps=True)
     assert all(isinstance(boxes, dict) for boxes in out)
     assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)
+    assert all(isinstance(seg_map, np.ndarray) for seg_map in seg_maps)
+    assert all(seg_map.shape[:2] == (1024, 1024) for seg_map in seg_maps)
+    # check that all values in the seg_maps are between 0 and 1
+    assert all((seg_map >= 0).all() and (seg_map <= 1).all() for seg_map in seg_maps)
 
 
 def test_erode():
diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py
index 188d2e7e01..18fc3bbe49 100644
--- a/tests/tensorflow/test_models_detection_tf.py
+++ b/tests/tensorflow/test_models_detection_tf.py
@@ -146,9 +146,13 @@ def test_detection_zoo(arch_name):
     # object check
     assert isinstance(predictor, DetectionPredictor)
     input_tensor = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1)
-    out = predictor(input_tensor)
+    out, seg_maps = predictor(input_tensor, return_maps=True)
     assert all(isinstance(boxes, dict) for boxes in out)
     assert all(isinstance(boxes[CLASS_NAME], np.ndarray) and boxes[CLASS_NAME].shape[1] == 5 for boxes in out)
+    assert all(isinstance(seg_map, np.ndarray) for seg_map in seg_maps)
+    assert all(seg_map.shape[:2] == (1024, 1024) for seg_map in seg_maps)
+    # check that all values in the seg_maps are between 0 and 1
+    assert all((seg_map >= 0).all() and (seg_map <= 1).all() for seg_map in seg_maps)
 
 
 def test_detection_zoo_error():
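Usage note (not part of the diff): with `get_bitmap_angle` removed, `estimate_orientation` becomes the single entry point for page-angle estimation, accepting either a raw page image or a binarized segmentation map. A minimal sketch of both call paths, assuming the changes above are merged; the image path and the zero-filled map are placeholders:

```python
import numpy as np

from doctr.io import DocumentFile
from doctr.models._utils import estimate_orientation

# Raw-image path: a (H, W, 3) uint8 page; grayscale conversion and
# Otsu binarization now happen inside estimate_orientation
page = DocumentFile.from_images("page.jpg")[0]  # "page.jpg" is a placeholder
angle = estimate_orientation(page)

# Bitmap path: a single-channel map (e.g. a thresholded detector out_map)
# is used as-is; here a zero map stands in for a real segmentation map
seg_map = np.zeros((1024, 1024, 1), dtype=np.uint8)
angle_from_map = estimate_orientation(seg_map)  # 0.0 for an empty map
```

Note the (H, W, C) expectation: the single-channel branch checks `img.shape[-1] == 1`, so a 2D bitmap needs a trailing channel axis first, which is exactly what the updated `mock_bitmap` fixture adds via `np.expand_dims`.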