From a31cd9464839850af46ca9b4adce5ae20054d6bc Mon Sep 17 00:00:00 2001 From: Edward Kim <109497216+edknv@users.noreply.github.com> Date: Thu, 14 Nov 2024 01:47:16 -0800 Subject: [PATCH] Update doughnut postprocessing (#223) --- .../pdf/doughnut_helper.py | 22 +- src/nv_ingest/util/nim/doughnut.py | 172 +++++++++ .../pdf/test_doughnut_utils.py | 333 ------------------ tests/nv_ingest/util/nim/test_doughnut.py | 143 ++++++++ 4 files changed, 330 insertions(+), 340 deletions(-) create mode 100644 src/nv_ingest/util/nim/doughnut.py delete mode 100644 tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py create mode 100644 tests/nv_ingest/util/nim/test_doughnut.py diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index bed95c97..8929f0fe 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -28,7 +28,6 @@ import pypdfium2 as pdfium import tritonclient.grpc as grpcclient -from nv_ingest.extraction_workflows.pdf import doughnut_utils from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum from nv_ingest.schemas.metadata_schema import ContentTypeEnum @@ -39,6 +38,7 @@ from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler from nv_ingest.util.image_processing.transforms import crop_image from nv_ingest.util.image_processing.transforms import numpy_to_base64 +from nv_ingest.util.nim import doughnut as doughnut_utils from nv_ingest.util.pdf.metadata_aggregators import Base64Image from nv_ingest.util.pdf.metadata_aggregators import LatexTable from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image @@ -50,6 +50,9 @@ DOUGHNUT_GRPC_TRITON = os.environ.get("DOUGHNUT_GRPC_TRITON", "triton:8001") DEFAULT_BATCH_SIZE = 16 +DEFAULT_RENDER_DPI = 300 +DEFAULT_MAX_WIDTH = 1024 +DEFAULT_MAX_HEIGHT = 1280 # Define a helper function to use doughnut to extract text from a base64 encoded bytestram PDF @@ -165,7 +168,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table txt = doughnut_utils.postprocess_text(txt, cls) if extract_images and identify_nearby_objects: - bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset) + bbox = doughnut_utils.reverse_transform_bbox( + bbox=bbox, + bbox_offset=bbox_offset, + original_width=DEFAULT_MAX_WIDTH, + original_height=DEFAULT_MAX_HEIGHT, + ) page_nearby_blocks["text"]["content"].append(txt) page_nearby_blocks["text"]["bbox"].append(bbox) @@ -182,8 +190,8 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table elif extract_images and (cls == "Picture"): if page_image is None: - scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT) - padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT) + scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) + padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) page_image, *_ = pdfium_pages_to_numpy( [pages[page_idx]], scale_tuple=scale_tuple, padding_tuple=padding_tuple ) @@ -280,9 +288,9 @@ def preprocess_and_send_requests( if not batch: return [] - render_dpi = 300 - scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT) - padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT) + render_dpi = DEFAULT_RENDER_DPI + scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) + padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) page_images, bbox_offsets = pdfium_pages_to_numpy( batch, render_dpi=render_dpi, scale_tuple=scale_tuple, padding_tuple=padding_tuple diff --git a/src/nv_ingest/util/nim/doughnut.py b/src/nv_ingest/util/nim/doughnut.py new file mode 100644 index 00000000..cffdbce8 --- /dev/null +++ b/src/nv_ingest/util/nim/doughnut.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +from math import ceil +from math import floor +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np + +from nv_ingest.util.image_processing.transforms import numpy_to_base64 + +ACCEPTED_TEXT_CLASSES = set( + [ + "Text", + "Title", + "Section-header", + "List-item", + "TOC", + "Bibliography", + "Formula", + "Page-header", + "Page-footer", + "Caption", + "Footnote", + "Floating-text", + ] +) +ACCEPTED_TABLE_CLASSES = set( + [ + "Table", + ] +) +ACCEPTED_IMAGE_CLASSES = set( + [ + "Picture", + ] +) +ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES + +_re_extract_class_bbox = re.compile( + r"((?:|.(?:(?", # noqa: E501 + re.MULTILINE | re.DOTALL, +) + +logger = logging.getLogger(__name__) + + +def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]: + classes: List[str] = [] + bboxes: List[Tuple[int, int, int, int]] = [] + texts: List[str] = [] + + last_end = 0 + + for m in _re_extract_class_bbox.finditer(text): + start, end = m.span() + + # [Bad box] Add the non-match chunk (text between the last match and the current match) + if start > last_end: + bad_text = text[last_end:start].strip() + classes.append("Bad-box") + bboxes.append((0, 0, 0, 0)) + texts.append(bad_text) + + last_end = end + + x1, y1, text, x2, y2, cls = m.groups() + + bbox = tuple(map(int, (x1, y1, x2, y2))) + + # [Bad box] check if the class is a valid class. + if cls not in ACCEPTED_CLASSES: + logger.debug(f"Dropped a bad box: invalid class {cls} at {bbox}.") + classes.append("Bad-box") + bboxes.append(bbox) + texts.append(text) + continue + + # Drop bad box: drop if the box is invalid. + if (bbox[0] >= bbox[2]) or (bbox[1] >= bbox[3]): + logger.debug(f"Dropped a bad box: invalid box {cls} at {bbox}.") + classes.append("Bad-box") + bboxes.append(bbox) + texts.append(text) + continue + + classes.append(cls) + bboxes.append(bbox) + texts.append(text) + + if last_end < len(text): + bad_text = text[last_end:].strip() + if len(bad_text) > 0: + classes.append("Bad-box") + bboxes.append((0, 0, 0, 0)) + texts.append(bad_text) + + return classes, bboxes, texts + + +def _fix_dots(m): + # Remove spaces between dots. + s = m.group(0) + return s.startswith(" ") * " " + min(5, s.count(".")) * "." + s.endswith(" ") * " " + + +def strip_markdown_formatting(text): + # Remove headers (e.g., # Header, ## Header, ### Header) + text = re.sub(r"^(#+)\s*(.*)", r"\2", text, flags=re.MULTILINE) + + # Remove bold formatting (e.g., **bold text** or __bold text__) + text = re.sub(r"\*\*(.*?)\*\*", r"\1", text) + text = re.sub(r"__(.*?)__", r"\1", text) + + # Remove italic formatting (e.g., *italic text* or _italic text_) + text = re.sub(r"\*(.*?)\*", r"\1", text) + text = re.sub(r"_(.*?)_", r"\1", text) + + # Remove strikethrough formatting (e.g., ~~strikethrough~~) + text = re.sub(r"~~(.*?)~~", r"\1", text) + + # Remove list items (e.g., - item, * item, 1. item) + text = re.sub(r"^\s*([-*+]|[0-9]+\.)\s+", "", text, flags=re.MULTILINE) + + # Remove hyperlinks (e.g., [link text](http://example.com)) + text = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text) + + # Remove inline code (e.g., `code`) + text = re.sub(r"`(.*?)`", r"\1", text) + + # Remove blockquotes (e.g., > quote) + text = re.sub(r"^\s*>\s*(.*)", r"\1", text, flags=re.MULTILINE) + + # Remove multiple newlines + text = re.sub(r"\n{3,}", "\n\n", text) + + # Limit dots sequences to max 5 dots + text = re.sub(r"(?:\s*\.\s*){3,}", _fix_dots, text, flags=re.DOTALL) + + return text + + +def reverse_transform_bbox( + bbox: Tuple[int, int, int, int], + bbox_offset: Tuple[int, int], + original_width: int, + original_height: int, +) -> Tuple[int, int, int, int]: + width_ratio = (original_width - 2 * bbox_offset[0]) / original_width + height_ratio = (original_height - 2 * bbox_offset[1]) / original_height + w1, h1, w2, h2 = bbox + w1 = int((w1 - bbox_offset[0]) / width_ratio) + h1 = int((h1 - bbox_offset[1]) / height_ratio) + w2 = int((w2 - bbox_offset[0]) / width_ratio) + h2 = int((h2 - bbox_offset[1]) / height_ratio) + + return (w1, h1, w2, h2) + + +def postprocess_text(txt: str, cls: str): + if cls in ACCEPTED_CLASSES: + txt = txt.replace("", "").strip() # remove tokens (continued paragraphs) + txt = strip_markdown_formatting(txt) + else: + txt = "" + + return txt diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py deleted file mode 100644 index 7b0c8bfd..00000000 --- a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import numpy as np -import PIL -import pytest - -from nv_ingest.extraction_workflows.pdf.doughnut_utils import convert_mmd_to_plain_text_ours -from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image -from nv_ingest.extraction_workflows.pdf.doughnut_utils import extract_classes_bboxes -from nv_ingest.extraction_workflows.pdf.doughnut_utils import pad_image -from nv_ingest.extraction_workflows.pdf.doughnut_utils import reverse_transform_bbox -from nv_ingest.extraction_workflows.pdf.doughnut_utils import postprocess_text - - -def test_pad_image_same_size(): - array = np.ones((100, 100, 3), dtype=np.uint8) - padded_array, (pad_width, pad_height) = pad_image(array, 100, 100) - - assert np.array_equal(padded_array, array) - assert pad_width == 0 - assert pad_height == 0 - - -def test_pad_image_smaller_size(): - array = np.ones((50, 50, 3), dtype=np.uint8) - padded_array, (pad_width, pad_height) = pad_image(array, 100, 100) - - assert padded_array.shape == (100, 100, 3) - assert pad_width == (100 - 50) // 2 - assert pad_height == (100 - 50) // 2 - assert np.array_equal(padded_array[pad_height : pad_height + 50, pad_width : pad_width + 50], array) # noqa: E203 - - -def test_pad_image_width_exceeds_target(): - array = np.ones((50, 150, 3), dtype=np.uint8) - with pytest.raises(ValueError, match="Image array is too large"): - pad_image(array, 100, 100) - - -def test_pad_image_height_exceeds_target(): - array = np.ones((150, 50, 3), dtype=np.uint8) - with pytest.raises(ValueError, match="Image array is too large"): - pad_image(array, 100, 100) - - -def test_pad_image_with_non_default_target(): - array = np.ones((60, 60, 3), dtype=np.uint8) - target_width = 80 - target_height = 80 - padded_array, (pad_width, pad_height) = pad_image(array, target_width, target_height) - - assert padded_array.shape == (target_height, target_width, 3) - assert pad_width == (target_width - 60) // 2 - assert pad_height == (target_height - 60) // 2 - assert np.array_equal(padded_array[pad_height : pad_height + 60, pad_width : pad_width + 60], array) # noqa: E203 - - -def test_extract_classes_bboxes_simple(): - text = "text1" - expected_classes = ["A"] - expected_bboxes = [(10, 20, 30, 40)] - expected_texts = ["text1"] - classes, bboxes, texts = extract_classes_bboxes(text) - assert classes == expected_classes - assert bboxes == expected_bboxes - assert texts == expected_texts - - -def test_extract_classes_bboxes_multiple(): - text = "text1\ntext2" - expected_classes = ["A", "B"] - expected_bboxes = [(10, 20, 30, 40), (50, 60, 70, 80)] - expected_texts = ["text1", "text2"] - classes, bboxes, texts = extract_classes_bboxes(text) - assert classes == expected_classes - assert bboxes == expected_bboxes - assert texts == expected_texts - - -def test_extract_classes_bboxes_no_match(): - text = "This text does not match the pattern" - expected_classes = [] - expected_bboxes = [] - expected_texts = [] - classes, bboxes, texts = extract_classes_bboxes(text) - assert classes == expected_classes - assert bboxes == expected_bboxes - assert texts == expected_texts - - -def test_extract_classes_bboxes_different_format(): - text = "sample" - expected_classes = ["test"] - expected_bboxes = [(1, 2, 3, 4)] - expected_texts = ["sample"] - classes, bboxes, texts = extract_classes_bboxes(text) - assert classes == expected_classes - assert bboxes == expected_bboxes - assert texts == expected_texts - - -def test_extract_classes_bboxes_empty_input(): - text = "" - expected_classes = [] - expected_bboxes = [] - expected_texts = [] - classes, bboxes, texts = extract_classes_bboxes(text) - assert classes == expected_classes - assert bboxes == expected_bboxes - assert texts == expected_texts - - -def test_convert_mmd_to_plain_text_ours_headers(): - mmd_text = "## Header" - expected = "Header" - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def test_convert_mmd_to_plain_text_ours_bold_italic(): - mmd_text = "This is **bold** and *italic* text." - expected = "This is bold and italic text." - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def test_convert_mmd_to_plain_text_ours_inline_math(): - mmd_text = "This is a formula: \\(E=mc^2\\)" - expected_with_math = "This is a formula: E=mc^2" - expected_without_math = "This is a formula:" - assert convert_mmd_to_plain_text_ours(mmd_text) == expected_with_math - assert convert_mmd_to_plain_text_ours(mmd_text, remove_inline_math=True) == expected_without_math - - -def test_convert_mmd_to_plain_text_ours_lists_tables(): - mmd_text = "* List item\n\\begin{table}content\\end{table}" - expected = "List item" - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def test_convert_mmd_to_plain_text_ours_code_blocks_equations(): - mmd_text = "```\ncode block\n```\n\\[ equation \\]" - expected = "" - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def test_convert_mmd_to_plain_text_ours_special_chars(): - mmd_text = "Backslash \\ should be removed." - expected = "Backslash should be removed." - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def test_convert_mmd_to_plain_text_ours_mixed_content(): - mmd_text = """ - ## Header - - This is **bold** and *italic* text with a formula: \\(E=mc^2\\). - - \\\[ equation \\\] - - \\begin{table}content\\end{table} - """ # noqa: W605 - expected = """ - Header - - This is bold and italic text with a formula: E=mc^2. - - """.strip() # noqa: W605 - assert convert_mmd_to_plain_text_ours(mmd_text) == expected - - -def create_test_image(width: int, height: int, color: tuple = (255, 0, 0)) -> PIL.Image: - """Helper function to create a solid color image for testing.""" - image = PIL.Image.new("RGB", (width, height), color) - return image - - -# def test_pymupdf_page_to_numpy_array_simple(): -# mock_page = Mock(spec=fitz.Page) -# mock_pixmap = Mock() -# mock_pixmap.pil_tobytes.return_value = b"fake image data" -# mock_page.get_pixmap.return_value = mock_pixmap -# -# with patch("PIL.Image.open", return_value=PIL.Image.new("RGB", (100, 100))): -# image, offset = pymupdf_page_to_numpy_array(mock_page, target_width=100, target_height=100) -# -# assert isinstance(image, np.ndarray) -# assert image.shape == (1280, 1024, 3) -# assert isinstance(offset, tuple) -# assert len(offset) == 2 -# assert all(isinstance(x, int) for x in offset) -# mock_page.get_pixmap.assert_called_once_with(dpi=300) -# mock_pixmap.pil_tobytes.assert_called_once_with(format="PNG") - - -# def test_pymupdf_page_to_numpy_array_different_dpi(): -# mock_page = Mock(spec=fitz.Page) -# mock_pixmap = Mock() -# mock_pixmap.pil_tobytes.return_value = b"fake image data" -# mock_page.get_pixmap.return_value = mock_pixmap - - -def test_crop_image_valid_bbox(): - array = np.ones((100, 100, 3), dtype=np.uint8) * 255 - bbox = (10, 10, 50, 50) - result = crop_image(array, bbox) - assert result is not None - assert isinstance(result, str) - - -def test_crop_image_partial_outside_bbox(): - array = np.ones((100, 100, 3), dtype=np.uint8) * 255 - bbox = (90, 90, 110, 110) - result = crop_image(array, bbox) - assert result is not None - assert isinstance(result, str) - - -def test_crop_image_completely_outside_bbox(): - array = np.ones((100, 100, 3), dtype=np.uint8) * 255 - bbox = (110, 110, 120, 120) - result = crop_image(array, bbox) - assert result is None - - -def test_crop_image_zero_area_bbox(): - array = np.ones((100, 100, 3), dtype=np.uint8) * 255 - bbox = (50, 50, 50, 50) - result = crop_image(array, bbox) - assert result is None - - -def test_crop_image_different_format(): - array = np.ones((100, 100, 3), dtype=np.uint8) * 255 - bbox = (10, 10, 50, 50) - result = crop_image(array, bbox, format="JPEG") - assert result is not None - assert isinstance(result, str) - - -def test_reverse_transform_bbox_no_offset(): - bbox = (10, 20, 30, 40) - bbox_offset = (0, 0) - expected_bbox = (10, 20, 30, 40) - transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) - - assert transformed_bbox == expected_bbox - - -def test_reverse_transform_bbox_with_offset(): - bbox = (20, 30, 40, 50) - bbox_offset = (10, 10) - expected_bbox = (12, 25, 37, 50) - transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) - - assert transformed_bbox == expected_bbox - - -def test_reverse_transform_bbox_with_large_offset(): - bbox = (60, 80, 90, 100) - bbox_offset = (20, 30) - width_ratio = (100 - 2 * bbox_offset[0]) / 100 - height_ratio = (100 - 2 * bbox_offset[1]) / 100 - expected_bbox = ( - int((60 - bbox_offset[0]) / width_ratio), - int((80 - bbox_offset[1]) / height_ratio), - int((90 - bbox_offset[0]) / width_ratio), - int((100 - bbox_offset[1]) / height_ratio), - ) - transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) - - assert transformed_bbox == expected_bbox - - -def test_reverse_transform_bbox_custom_dimensions(): - bbox = (15, 25, 35, 45) - bbox_offset = (5, 5) - original_width = 200 - original_height = 200 - width_ratio = (original_width - 2 * bbox_offset[0]) / original_width - height_ratio = (original_height - 2 * bbox_offset[1]) / original_height - expected_bbox = ( - int((15 - bbox_offset[0]) / width_ratio), - int((25 - bbox_offset[1]) / height_ratio), - int((35 - bbox_offset[0]) / width_ratio), - int((45 - bbox_offset[1]) / height_ratio), - ) - transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, original_width, original_height) - - assert transformed_bbox == expected_bbox - - -def test_reverse_transform_bbox_zero_dimension(): - bbox = (10, 10, 20, 20) - bbox_offset = (0, 0) - original_width = 0 - original_height = 0 - with pytest.raises(ZeroDivisionError): - reverse_transform_bbox(bbox, bbox_offset, original_width, original_height) - - -def test_postprocess_text_with_unaccepted_class(): - # Input text that should not be processed - txt = "This text should not be processed" - cls = "InvalidClass" # Not in ACCEPTED_CLASSES - - result = postprocess_text(txt, cls) - - assert result == "" - - -def test_postprocess_text_removes_tbc_and_processes_text(): - # Input text containing "" - txt = "Some text" - cls = "Title" # An accepted class - - expected_output = "Some text" - - result = postprocess_text(txt, cls) - - assert result == expected_output - - -def test_postprocess_text_no_tbc_but_accepted_class(): - # Input text without "" - txt = "This is a test **without** tbc" - cls = "Section-header" # An accepted class - - expected_output = "This is a test without tbc" - - result = postprocess_text(txt, cls) - - assert result == expected_output diff --git a/tests/nv_ingest/util/nim/test_doughnut.py b/tests/nv_ingest/util/nim/test_doughnut.py new file mode 100644 index 00000000..656acf8f --- /dev/null +++ b/tests/nv_ingest/util/nim/test_doughnut.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nv_ingest.util.nim.doughnut import extract_classes_bboxes +from nv_ingest.util.nim.doughnut import postprocess_text +from nv_ingest.util.nim.doughnut import reverse_transform_bbox +from nv_ingest.util.nim.doughnut import strip_markdown_formatting + + +def test_reverse_transform_bbox_no_offset(): + bbox = (10, 20, 30, 40) + bbox_offset = (0, 0) + expected_bbox = (10, 20, 30, 40) + transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) + + assert transformed_bbox == expected_bbox + + +def test_reverse_transform_bbox_with_offset(): + bbox = (20, 30, 40, 50) + bbox_offset = (10, 10) + expected_bbox = (12, 25, 37, 50) + transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) + + assert transformed_bbox == expected_bbox + + +def test_reverse_transform_bbox_with_large_offset(): + bbox = (60, 80, 90, 100) + bbox_offset = (20, 30) + width_ratio = (100 - 2 * bbox_offset[0]) / 100 + height_ratio = (100 - 2 * bbox_offset[1]) / 100 + expected_bbox = ( + int((60 - bbox_offset[0]) / width_ratio), + int((80 - bbox_offset[1]) / height_ratio), + int((90 - bbox_offset[0]) / width_ratio), + int((100 - bbox_offset[1]) / height_ratio), + ) + transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100) + + assert transformed_bbox == expected_bbox + + +def test_reverse_transform_bbox_custom_dimensions(): + bbox = (15, 25, 35, 45) + bbox_offset = (5, 5) + original_width = 200 + original_height = 200 + width_ratio = (original_width - 2 * bbox_offset[0]) / original_width + height_ratio = (original_height - 2 * bbox_offset[1]) / original_height + expected_bbox = ( + int((15 - bbox_offset[0]) / width_ratio), + int((25 - bbox_offset[1]) / height_ratio), + int((35 - bbox_offset[0]) / width_ratio), + int((45 - bbox_offset[1]) / height_ratio), + ) + transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, original_width, original_height) + + assert transformed_bbox == expected_bbox + + +def test_reverse_transform_bbox_zero_dimension(): + bbox = (10, 10, 20, 20) + bbox_offset = (0, 0) + original_width = 0 + original_height = 0 + with pytest.raises(ZeroDivisionError): + reverse_transform_bbox(bbox, bbox_offset, original_width, original_height) + + +def test_postprocess_text_with_unaccepted_class(): + # Input text that should not be processed + txt = "This text should not be processed" + cls = "InvalidClass" # Not in ACCEPTED_CLASSES + + result = postprocess_text(txt, cls) + + assert result == "" + + +def test_postprocess_text_removes_tbc_and_processes_text(): + # Input text containing "" + txt = "Some text" + cls = "Title" # An accepted class + + expected_output = "Some text" + + result = postprocess_text(txt, cls) + + assert result == expected_output + + +def test_postprocess_text_no_tbc_but_accepted_class(): + # Input text without "" + txt = "This is a test **without** tbc" + cls = "Section-header" # An accepted class + + expected_output = "This is a test without tbc" + + result = postprocess_text(txt, cls) + + assert result == expected_output + + +@pytest.mark.parametrize( + "input_text, expected_classes, expected_bboxes, expected_texts", + [ + ("Sample text", ["Text"], [(10, 20, 30, 40)], ["Sample text"]), + ( + "Invalid text ", + ["Bad-box", "Bad-box"], + [(0, 0, 0, 0), (10, 20, 30, 40)], + ["Invalid text", ""], + ), + ("Header content", ["Title"], [(15, 25, 35, 45)], ["Header content"]), + ("Overlapping box", ["Bad-box"], [(5, 10, 5, 10)], ["Overlapping box"]), + ], +) +def test_extract_classes_bboxes(input_text, expected_classes, expected_bboxes, expected_texts): + classes, bboxes, texts = extract_classes_bboxes(input_text) + assert classes == expected_classes + assert bboxes == expected_bboxes + assert texts == expected_texts + + +# Test cases for strip_markdown_formatting +@pytest.mark.parametrize( + "input_text, expected_output", + [ + ("# Header\n**Bold text**\n*Italic*", "Header\nBold text\nItalic"), + ("~~Strikethrough~~", "Strikethrough"), + ("[Link](http://example.com)", "Link"), + ("`inline code`", "inline code"), + ("> Blockquote", "Blockquote"), + ("Normal text\n\n\nMultiple newlines", "Normal text\n\nMultiple newlines"), + ("Dot sequence...... more text", "Dot sequence..... more text"), + ], +) +def test_strip_markdown_formatting(input_text, expected_output): + assert strip_markdown_formatting(input_text) == expected_output