Skip to content

Commit

Permalink
Update doughnut postprocessing (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv authored Nov 14, 2024
1 parent 7e06db2 commit a31cd94
Show file tree
Hide file tree
Showing 4 changed files with 330 additions and 340 deletions.
22 changes: 15 additions & 7 deletions src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import pypdfium2 as pdfium
import tritonclient.grpc as grpcclient

from nv_ingest.extraction_workflows.pdf import doughnut_utils
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
Expand All @@ -39,6 +38,7 @@
from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler
from nv_ingest.util.image_processing.transforms import crop_image
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim import doughnut as doughnut_utils
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import LatexTable
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
Expand All @@ -50,6 +50,9 @@

DOUGHNUT_GRPC_TRITON = os.environ.get("DOUGHNUT_GRPC_TRITON", "triton:8001")
DEFAULT_BATCH_SIZE = 16
DEFAULT_RENDER_DPI = 300
DEFAULT_MAX_WIDTH = 1024
DEFAULT_MAX_HEIGHT = 1280


# Define a helper function to use doughnut to extract text from a base64 encoded bytestram PDF
Expand Down Expand Up @@ -165,7 +168,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
txt = doughnut_utils.postprocess_text(txt, cls)

if extract_images and identify_nearby_objects:
bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset)
bbox = doughnut_utils.reverse_transform_bbox(
bbox=bbox,
bbox_offset=bbox_offset,
original_width=DEFAULT_MAX_WIDTH,
original_height=DEFAULT_MAX_HEIGHT,
)
page_nearby_blocks["text"]["content"].append(txt)
page_nearby_blocks["text"]["bbox"].append(bbox)

Expand All @@ -182,8 +190,8 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table

elif extract_images and (cls == "Picture"):
if page_image is None:
scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
page_image, *_ = pdfium_pages_to_numpy(
[pages[page_idx]], scale_tuple=scale_tuple, padding_tuple=padding_tuple
)
Expand Down Expand Up @@ -280,9 +288,9 @@ def preprocess_and_send_requests(
if not batch:
return []

render_dpi = 300
scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
render_dpi = DEFAULT_RENDER_DPI
scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)

page_images, bbox_offsets = pdfium_pages_to_numpy(
batch, render_dpi=render_dpi, scale_tuple=scale_tuple, padding_tuple=padding_tuple
Expand Down
172 changes: 172 additions & 0 deletions src/nv_ingest/util/nim/doughnut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import re
from math import ceil
from math import floor
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np

from nv_ingest.util.image_processing.transforms import numpy_to_base64

ACCEPTED_TEXT_CLASSES = set(
[
"Text",
"Title",
"Section-header",
"List-item",
"TOC",
"Bibliography",
"Formula",
"Page-header",
"Page-footer",
"Caption",
"Footnote",
"Floating-text",
]
)
ACCEPTED_TABLE_CLASSES = set(
[
"Table",
]
)
ACCEPTED_IMAGE_CLASSES = set(
[
"Picture",
]
)
ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES

_re_extract_class_bbox = re.compile(
r"<x_(\d+)><y_(\d+)>((?:|.(?:(?<!<x_\d)(?<!<y_\d)(?<!<class_[A-Za-z0-9]).)*))<x_(\d+)><y_(\d+)><class_([A-Za-z0-9\-]+)>", # noqa: E501
re.MULTILINE | re.DOTALL,
)

logger = logging.getLogger(__name__)


def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]:
classes: List[str] = []
bboxes: List[Tuple[int, int, int, int]] = []
texts: List[str] = []

last_end = 0

for m in _re_extract_class_bbox.finditer(text):
start, end = m.span()

# [Bad box] Add the non-match chunk (text between the last match and the current match)
if start > last_end:
bad_text = text[last_end:start].strip()
classes.append("Bad-box")
bboxes.append((0, 0, 0, 0))
texts.append(bad_text)

last_end = end

x1, y1, text, x2, y2, cls = m.groups()

bbox = tuple(map(int, (x1, y1, x2, y2)))

# [Bad box] check if the class is a valid class.
if cls not in ACCEPTED_CLASSES:
logger.debug(f"Dropped a bad box: invalid class {cls} at {bbox}.")
classes.append("Bad-box")
bboxes.append(bbox)
texts.append(text)
continue

# Drop bad box: drop if the box is invalid.
if (bbox[0] >= bbox[2]) or (bbox[1] >= bbox[3]):
logger.debug(f"Dropped a bad box: invalid box {cls} at {bbox}.")
classes.append("Bad-box")
bboxes.append(bbox)
texts.append(text)
continue

classes.append(cls)
bboxes.append(bbox)
texts.append(text)

if last_end < len(text):
bad_text = text[last_end:].strip()
if len(bad_text) > 0:
classes.append("Bad-box")
bboxes.append((0, 0, 0, 0))
texts.append(bad_text)

return classes, bboxes, texts


def _fix_dots(m):
# Remove spaces between dots.
s = m.group(0)
return s.startswith(" ") * " " + min(5, s.count(".")) * "." + s.endswith(" ") * " "


def strip_markdown_formatting(text):
# Remove headers (e.g., # Header, ## Header, ### Header)
text = re.sub(r"^(#+)\s*(.*)", r"\2", text, flags=re.MULTILINE)

# Remove bold formatting (e.g., **bold text** or __bold text__)
text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
text = re.sub(r"__(.*?)__", r"\1", text)

# Remove italic formatting (e.g., *italic text* or _italic text_)
text = re.sub(r"\*(.*?)\*", r"\1", text)
text = re.sub(r"_(.*?)_", r"\1", text)

# Remove strikethrough formatting (e.g., ~~strikethrough~~)
text = re.sub(r"~~(.*?)~~", r"\1", text)

# Remove list items (e.g., - item, * item, 1. item)
text = re.sub(r"^\s*([-*+]|[0-9]+\.)\s+", "", text, flags=re.MULTILINE)

# Remove hyperlinks (e.g., [link text](http://example.com))
text = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text)

# Remove inline code (e.g., `code`)
text = re.sub(r"`(.*?)`", r"\1", text)

# Remove blockquotes (e.g., > quote)
text = re.sub(r"^\s*>\s*(.*)", r"\1", text, flags=re.MULTILINE)

# Remove multiple newlines
text = re.sub(r"\n{3,}", "\n\n", text)

# Limit dots sequences to max 5 dots
text = re.sub(r"(?:\s*\.\s*){3,}", _fix_dots, text, flags=re.DOTALL)

return text


def reverse_transform_bbox(
bbox: Tuple[int, int, int, int],
bbox_offset: Tuple[int, int],
original_width: int,
original_height: int,
) -> Tuple[int, int, int, int]:
width_ratio = (original_width - 2 * bbox_offset[0]) / original_width
height_ratio = (original_height - 2 * bbox_offset[1]) / original_height
w1, h1, w2, h2 = bbox
w1 = int((w1 - bbox_offset[0]) / width_ratio)
h1 = int((h1 - bbox_offset[1]) / height_ratio)
w2 = int((w2 - bbox_offset[0]) / width_ratio)
h2 = int((h2 - bbox_offset[1]) / height_ratio)

return (w1, h1, w2, h2)


def postprocess_text(txt: str, cls: str):
if cls in ACCEPTED_CLASSES:
txt = txt.replace("<tbc>", "").strip() # remove <tbc> tokens (continued paragraphs)
txt = strip_markdown_formatting(txt)
else:
txt = ""

return txt
Loading

0 comments on commit a31cd94

Please sign in to comment.