diff --git a/client/client_examples/examples/python_client_usage.ipynb b/client/client_examples/examples/python_client_usage.ipynb
index dab70631..99185b1e 100644
--- a/client/client_examples/examples/python_client_usage.ipynb
+++ b/client/client_examples/examples/python_client_usage.ipynb
@@ -630,6 +630,7 @@
     "job_spec.add_task(extract_task)\n",
     "job_spec.add_task(dedup_task)\n",
     "job_spec.add_task(filter_task)\n",
+    "job_spec.add_task(split_task)\n",
     "job_spec.add_task(store_task)\n",
     "job_spec.add_task(embed_task)\n",
     "job_spec.add_task(vdb_upload_task)"
diff --git a/client/src/nv_ingest_client/cli/util/processing.py b/client/src/nv_ingest_client/cli/util/processing.py
index b452bc10..acbb3d23 100644
--- a/client/src/nv_ingest_client/cli/util/processing.py
+++ b/client/src/nv_ingest_client/cli/util/processing.py
@@ -652,7 +652,7 @@ def create_and_process_jobs(
                 failed_jobs.append(f"{job_id}::{source_name}")
             except RuntimeError as e:
                 source_name = job_id_map[job_id]
-                logger.error(f"Error while processing {job_id}({source_name}) {e}")
+                logger.error(f"Error while processing '{job_id}' - ({source_name}):\n{e}")
                 failed_jobs.append(f"{job_id}::{source_name}")
             except Exception as e:
                 traceback.print_exc()
diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py
index ea8b4189..4899e1fc 100644
--- a/client/src/nv_ingest_client/client/client.py
+++ b/client/src/nv_ingest_client/client/client.py
@@ -88,6 +88,7 @@ def __init__(
         self._current_message_id = 0
         self._job_states = {}
+        self._job_index_to_job_spec = {}
         self._message_client_hostname = message_client_hostname or "localhost"
         self._message_client_port = message_client_port or 7670
         self._message_counter_id = msg_counter_id or "nv-ingest-message-id"
@@ -177,9 +178,11 @@ def _add_single_job(self, job_spec: JobSpec) -> str:

         return job_index

-    def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> str:
+    def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> Union[str, List[str]]:
         if isinstance(job_spec, JobSpec):
             job_index = self._add_single_job(job_spec)
+            self._job_index_to_job_spec[job_index] = job_spec
+
             return job_index
         elif isinstance(job_spec, BatchJobSpec):
             job_indexes = []
@@ -187,6 +190,7 @@ def add_job(self, job_spec: Union[BatchJobSpec, JobSpec]) -> str:
             for job in job_specs:
                 job_index = self._add_single_job(job)
                 job_indexes.append(job_index)
+                self._job_index_to_job_spec[job_index] = job
             return job_indexes
         else:
             raise ValueError(f"Unexpected type: {type(job_spec)}")
@@ -241,7 +245,8 @@ def create_job(
             extended_options=extended_options,
         )

-        return self.add_job(job_spec)
+        job_id = self.add_job(job_spec)
+        return job_id

     def add_task(self, job_index: str, task: Task) -> None:
         job_state = self._get_and_check_job_state(job_index, required_state=JobStateEnum.PENDING)
@@ -295,7 +300,8 @@ def _fetch_job_result(self, job_index: str, timeout: float = 100, data_only: boo
         """

         try:
-            job_state = self._get_and_check_job_state(job_index, required_state=[JobStateEnum.SUBMITTED, JobStateEnum.SUBMITTED_ASYNC])
+            job_state = self._get_and_check_job_state(job_index, required_state=[JobStateEnum.SUBMITTED,
+                                                                                 JobStateEnum.SUBMITTED_ASYNC])
             response = self._message_client.fetch_message(job_state.job_id, timeout)

             if response is not None:
@@ -345,12 +351,12 @@ def _fetch_job_result_wait(self, job_id: str, timeout: float = 60, data_only: bo
     # This is the direct Python approach function for retrieving jobs which handles the timeouts directly
     # in the function itself instead of expecting the user to handle it themselves
     def fetch_job_result(
-        self,
-        job_ids: Union[str, List[str]],
-        timeout: float = 100,
-        max_retries: Optional[int] = None,
-        retry_delay: float = 1,
-        verbose: bool = False,
+            self,
+            job_ids: Union[str, List[str]],
+            timeout: float = 100,
+            max_retries: Optional[int] = None,
+            retry_delay: float = 1,
+            verbose: bool = False,
     ) -> List[Tuple[Optional[Dict], str]]:
         """
         Fetches job results for multiple job IDs concurrently with individual timeouts and retry logic.
@@ -410,14 +416,19 @@ def fetch_with_retries(job_id: str):
             try:
                 result = handle_future_result(future, timeout=timeout)
                 results.append(result.get("data"))
+                del self._job_index_to_job_spec[job_id]
             except concurrent.futures.TimeoutError:
-                logger.error(f"Timeout while fetching result for job ID {job_id}")
+                logger.error(
+                    f"Timeout while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}")
             except json.JSONDecodeError as e:
-                logger.error(f"Decoding while processing job ID {job_id}: {e}")
+                logger.error(
+                    f"Decoding error while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}")
             except RuntimeError as e:
-                logger.error(f"Error while processing job ID {job_id}: {e}")
+                logger.error(
+                    f"Error while processing job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}")
             except Exception as e:
-                logger.error(f"Error while fetching result for job ID {job_id}: {e}")
+                logger.error(
+                    f"Error while fetching result for job ID {job_id}: {self._job_index_to_job_spec[job_id].source_id}\n{e}")

         return results
@@ -585,7 +596,7 @@ def submit_job_async(self, job_indices: Union[str, List[str]], job_queue_id: str
         return future_to_job_index

-    def create_jobs_for_batch(self, files_batch: List[str], tasks: Dict[str, Any]) -> List[JobSpec]:
+    def create_jobs_for_batch(self, files_batch: List[str], tasks: Dict[str, Any]) -> List[str]:
         """
         Create and submit job specifications (JobSpecs) for a batch of files, returning the job IDs.

         This function takes a batch of files, processes each file to extract its content and type,
diff --git a/client/src/nv_ingest_client/client/interface.py b/client/src/nv_ingest_client/client/interface.py
index 7523054e..70cd0dc2 100644
--- a/client/src/nv_ingest_client/client/interface.py
+++ b/client/src/nv_ingest_client/client/interface.py
@@ -65,11 +65,11 @@ class Ingestor:
     """

     def __init__(
-        self,
-        documents: Optional[List[str]] = None,
-        client: Optional[NvIngestClient] = None,
-        job_queue_id: str = DEFAULT_JOB_QUEUE_ID,
-        **kwargs,
+            self,
+            documents: Optional[List[str]] = None,
+            client: Optional[NvIngestClient] = None,
+            job_queue_id: str = DEFAULT_JOB_QUEUE_ID,
+            **kwargs,
     ):
         self._documents = documents or []
         self._client = client
@@ -83,6 +83,7 @@ def __init__(
         self._job_specs = None
         self._job_ids = None
         self._job_states = None
+        self._job_id_to_source_id = {}

         if self._check_files_local():
             self._job_specs = BatchJobSpec(self._documents)
@@ -242,6 +243,7 @@ def ingest_async(self, **kwargs: Any) -> Future:
         self._prepare_ingest_run()

         self._job_ids = self._client.add_job(self._job_specs)
+
         future_to_job_id = self._client.submit_job_async(self._job_ids, self._job_queue_id, **kwargs)

         self._job_states = {job_id: self._client._get_and_check_job_state(job_id) for job_id in self._job_ids}
@@ -300,8 +302,8 @@ def all_tasks(self) -> "Ingestor":
             .filter() \
             .split() \
             .embed()
-            # .store() \
-            # .vdb_upload()
+        # .store() \
+        # .vdb_upload()
         # fmt: on
         return self
@@ -360,11 +362,13 @@ def extract(self, **kwargs: Any) -> "Ingestor":
         Ingestor
             Returns self for chaining.
         """
-        extract_tables = kwargs.get("extract_tables", False)
-        extract_charts = kwargs.get("extract_charts", False)
+        extract_tables = kwargs.pop("extract_tables", True)
+        extract_charts = kwargs.pop("extract_charts", True)

         for document_type in self._job_specs.file_types:
-            extract_task = ExtractTask(document_type, **kwargs)
+            extract_task = ExtractTask(
+                document_type, extract_tables=extract_tables, extract_charts=extract_charts, **kwargs
+            )
             self._job_specs.add_task(extract_task, document_type=document_type)

         if extract_tables is True:
diff --git a/client/src/nv_ingest_client/primitives/tasks/__init__.py b/client/src/nv_ingest_client/primitives/tasks/__init__.py
index 99fae562..8f32e6d8 100644
--- a/client/src/nv_ingest_client/primitives/tasks/__init__.py
+++ b/client/src/nv_ingest_client/primitives/tasks/__init__.py
@@ -3,12 +3,14 @@
 # SPDX-License-Identifier: Apache-2.0

 from .caption import CaptionTask
+from .chart_extraction import ChartExtractionTask
 from .dedup import DedupTask
 from .embed import EmbedTask
 from .extract import ExtractTask
 from .filter import FilterTask
 from .split import SplitTask
 from .store import StoreTask
+from .table_extraction import TableExtractionTask
 from .task_base import Task
 from .task_base import TaskType
 from .task_base import is_valid_task_type
@@ -17,10 +19,12 @@

 __all__ = [
     "CaptionTask",
+    "ChartExtractionTask",
     "ExtractTask",
     "is_valid_task_type",
     "SplitTask",
     "StoreTask",
+    "TableExtractionTask",
     "Task",
     "task_factory",
     "TaskType",
diff --git a/client/src/nv_ingest_client/util/util.py b/client/src/nv_ingest_client/util/util.py
index ebb6db3e..adb67a91 100644
--- a/client/src/nv_ingest_client/util/util.py
+++ b/client/src/nv_ingest_client/util/util.py
@@ -24,7 +24,6 @@

 logger = logging.getLogger(__name__)

-
 # pylint: disable=invalid-name
 # pylint: disable=missing-class-docstring
 # pylint: disable=logging-fstring-interpolation
@@ -257,14 +256,25 @@ def check_ingest_result(json_payload: Dict) -> typing.Tuple[bool, str]:
     )

     is_failed = json_payload.get("status", "") in "failed"
-    description = json_payload.get("description", "")
+    description = ""
+    if is_failed:
+        try:
+            source_id = json_payload.get("data", [])[0].get("metadata", {}).get("source_metadata", {}).get(
+                "source_name",
+                "")
+        except Exception:
+            source_id = ""
+
+        description = f"[{source_id}]: {json_payload.get('status', '')}\n"
+
+    description += json_payload.get("description", "")

     # Look to see if we have any failure annotations to augment the description
     if is_failed and "annotations" in json_payload:
         for annot_id, value in json_payload["annotations"].items():
             if "task_result" in value and value["task_result"] == "FAILURE":
                 message = value.get("message", "Unknown")
-                description = f"\n↪ Event that caused this failure: {annot_id} -> {message}"
+                description += f"\n↪ Event that caused this failure: {annot_id} -> {message}"
                 break

     return is_failed, description
diff --git a/config/otel-collector-config.yaml b/config/otel-collector-config.yaml
index cdfa8f32..c05a1fc2 100644
--- a/config/otel-collector-config.yaml
+++ b/config/otel-collector-config.yaml
@@ -19,6 +19,19 @@ exporters:

 processors:
   batch:
+  tail_sampling:
+    policies: [
+      {
+        name: filter_http_url,
+        type: string_attribute,
+        string_attribute: {
+          key: http.route,
+          values: [ "/health/ready" ],
+          enabled_regex_matching: true,
+          invert_match: true
+        }
+      }
+    ]

 extensions:
   health_check:
@@ -32,7 +45,7 @@ service:
   pipelines:
     traces:
       receivers: [otlp]
-      processors: [batch]
+      processors: [batch, tail_sampling]
       exporters: [zipkin, logging]
     metrics:
       receivers: [otlp]
diff --git a/requirements.txt b/requirements.txt
index 5362d5d5..11007c2b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-aiohttp==3.9.4
+aiohttp==3.10.11
 charset-normalizer
 click
 opencv-python
diff --git a/src/nv_ingest/api/main.py b/src/nv_ingest/api/main.py
index 9beba4f6..d6ec468b 100644
--- a/src/nv_ingest/api/main.py
+++ b/src/nv_ingest/api/main.py
@@ -13,6 +13,7 @@

 from opentelemetry import trace
 from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import Resource
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
 from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@@ -22,7 +23,8 @@
 from .v1.ingest import router as IngestApiRouter

 # Set up the tracer provider and add a processor for exporting traces
-trace.set_tracer_provider(TracerProvider())
+resource = Resource(attributes={"service.name": "nv-ingest"})
+trace.set_tracer_provider(TracerProvider(resource=resource))
 tracer = trace.get_tracer(__name__)

 otel_endpoint = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "otel-collector:4317")
diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
index bed95c97..8929f0fe 100644
--- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
@@ -28,7 +28,6 @@
 import pypdfium2 as pdfium
 import tritonclient.grpc as grpcclient

-from nv_ingest.extraction_workflows.pdf import doughnut_utils
 from nv_ingest.schemas.metadata_schema import AccessLevelEnum
 from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum
 from nv_ingest.schemas.metadata_schema import ContentTypeEnum
@@ -39,6 +38,7 @@
 from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler
 from nv_ingest.util.image_processing.transforms import crop_image
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.nim import doughnut as doughnut_utils
 from nv_ingest.util.pdf.metadata_aggregators import Base64Image
 from nv_ingest.util.pdf.metadata_aggregators import LatexTable
 from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
@@ -50,6 +50,9 @@
 DOUGHNUT_GRPC_TRITON = os.environ.get("DOUGHNUT_GRPC_TRITON", "triton:8001")
 DEFAULT_BATCH_SIZE = 16
+DEFAULT_RENDER_DPI = 300
+DEFAULT_MAX_WIDTH = 1024
+DEFAULT_MAX_HEIGHT = 1280


 # Define a helper function to use doughnut to extract text from a base64 encoded bytestream PDF
@@ -165,7 +168,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
                     txt = doughnut_utils.postprocess_text(txt, cls)

                     if extract_images and identify_nearby_objects:
-                        bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset)
+                        bbox = doughnut_utils.reverse_transform_bbox(
+                            bbox=bbox,
+                            bbox_offset=bbox_offset,
+                            original_width=DEFAULT_MAX_WIDTH,
+                            original_height=DEFAULT_MAX_HEIGHT,
+                        )
                         page_nearby_blocks["text"]["content"].append(txt)
                         page_nearby_blocks["text"]["bbox"].append(bbox)
@@ -182,8 +190,8 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table

                 elif extract_images and (cls == "Picture"):
                     if page_image is None:
-                        scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
-                        padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
+                        scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
+                        padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
                         page_image, *_ = pdfium_pages_to_numpy(
                             [pages[page_idx]], scale_tuple=scale_tuple, padding_tuple=padding_tuple
                         )
@@ -280,9 +288,9 @@ def preprocess_and_send_requests(
     if not batch:
         return []

-    render_dpi = 300
-    scale_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
-    padding_tuple = (doughnut_utils.DEFAULT_MAX_WIDTH, doughnut_utils.DEFAULT_MAX_HEIGHT)
+    render_dpi = DEFAULT_RENDER_DPI
+    scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
+    padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)

     page_images, bbox_offsets = pdfium_pages_to_numpy(
         batch, render_dpi=render_dpi, scale_tuple=scale_tuple, padding_tuple=padding_tuple
diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 9216c12f..9ea8747f 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -471,15 +471,18 @@ def pdfium_extractor(
             pdfium_config,
             trace_info=trace_info,
         ):
-            extracted_data.append(
-                construct_table_and_chart_metadata(
-                    table_and_charts,
-                    page_idx,
-                    pdf_metadata.page_count,
-                    source_metadata,
-                    base_unified_metadata,
+            if (extract_tables and (table_and_charts.type_string == "table")) or (
+                extract_charts and (table_and_charts.type_string == "chart")
+            ):
+                extracted_data.append(
+                    construct_table_and_chart_metadata(
+                        table_and_charts,
+                        page_idx,
+                        pdf_metadata.page_count,
+                        source_metadata,
+                        base_unified_metadata,
+                    )
                 )
-            )

     logger.debug(f"Extracted {len(extracted_data)} items from PDF.")
diff --git a/src/nv_ingest/modules/telemetry/otel_tracer.py b/src/nv_ingest/modules/telemetry/otel_tracer.py
index 05be65ae..100d398a 100644
--- a/src/nv_ingest/modules/telemetry/otel_tracer.py
+++ b/src/nv_ingest/modules/telemetry/otel_tracer.py
@@ -83,7 +83,7 @@ def collect_timestamps(message):
             is_remote=True,
             trace_flags=TraceFlags(0x01),
         )
-        parent_ctx = trace.set_span_in_context(NonRecordingSpan(span_context))
+        parent_ctx = trace.set_span_in_context(span_context)
         parent_span = tracer.start_span(job_id, context=parent_ctx, start_time=start_time)

         create_span_with_timestamps(tracer, parent_span, message)
diff --git a/src/nv_ingest/util/nim/doughnut.py b/src/nv_ingest/util/nim/doughnut.py
new file mode 100644
index 00000000..cffdbce8
--- /dev/null
+++ b/src/nv_ingest/util/nim/doughnut.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+from math import ceil
+from math import floor
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+
+from nv_ingest.util.image_processing.transforms import numpy_to_base64
+
+ACCEPTED_TEXT_CLASSES = set(
+    [
+        "Text",
+        "Title",
+        "Section-header",
+        "List-item",
+        "TOC",
+        "Bibliography",
+        "Formula",
+        "Page-header",
+        "Page-footer",
+        "Caption",
+        "Footnote",
+        "Floating-text",
+    ]
+)
+ACCEPTED_TABLE_CLASSES = set(
+    [
+        "Table",
+    ]
+)
+ACCEPTED_IMAGE_CLASSES = set(
+    [
+        "Picture",
+    ]
+)
+ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES
+
+# Matches one doughnut block: <x_{x1}><y_{y1}>{text}<x_{x2}><y_{y2}><class_{cls}>
+_re_extract_class_bbox = re.compile(
+    r"<x_(\d+)><y_(\d+)>((?:<tbc>|.)*?)<x_(\d+)><y_(\d+)><class_([^>]+)>",  # noqa: E501
+    re.MULTILINE | re.DOTALL,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]:
+    classes: List[str] = []
+    bboxes: List[Tuple[int, int, int, int]] = []
+    texts: List[str] = []
+
+    last_end = 0
+
+    for m in _re_extract_class_bbox.finditer(text):
+        start, end = m.span()
+
+        # [Bad box] Add the non-match chunk (text between the last match and the current match)
+        if start > last_end:
+            bad_text = text[last_end:start].strip()
+            classes.append("Bad-box")
+            bboxes.append((0, 0, 0, 0))
+            texts.append(bad_text)
+
+        last_end = end
+
+        x1, y1, txt, x2, y2, cls = m.groups()
+
+        bbox = tuple(map(int, (x1, y1, x2, y2)))
+
+        # [Bad box] check if the class is a valid class.
+        if cls not in ACCEPTED_CLASSES:
+            logger.debug(f"Dropped a bad box: invalid class {cls} at {bbox}.")
+            classes.append("Bad-box")
+            bboxes.append(bbox)
+            texts.append(txt)
+            continue
+
+        # Drop bad box: drop if the box is invalid.
+        if (bbox[0] >= bbox[2]) or (bbox[1] >= bbox[3]):
+            logger.debug(f"Dropped a bad box: invalid box {cls} at {bbox}.")
+            classes.append("Bad-box")
+            bboxes.append(bbox)
+            texts.append(txt)
+            continue
+
+        classes.append(cls)
+        bboxes.append(bbox)
+        texts.append(txt)
+
+    if last_end < len(text):
+        bad_text = text[last_end:].strip()
+        if len(bad_text) > 0:
+            classes.append("Bad-box")
+            bboxes.append((0, 0, 0, 0))
+            texts.append(bad_text)
+
+    return classes, bboxes, texts
+
+
+def _fix_dots(m):
+    # Remove spaces between dots.
+    s = m.group(0)
+    return s.startswith(" ") * " " + min(5, s.count(".")) * "." + s.endswith(" ") * " "
+
+
+def strip_markdown_formatting(text):
+    # Remove headers (e.g., # Header, ## Header, ### Header)
+    text = re.sub(r"^(#+)\s*(.*)", r"\2", text, flags=re.MULTILINE)
+
+    # Remove bold formatting (e.g., **bold text** or __bold text__)
+    text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
+    text = re.sub(r"__(.*?)__", r"\1", text)
+
+    # Remove italic formatting (e.g., *italic text* or _italic text_)
+    text = re.sub(r"\*(.*?)\*", r"\1", text)
+    text = re.sub(r"_(.*?)_", r"\1", text)
+
+    # Remove strikethrough formatting (e.g., ~~strikethrough~~)
+    text = re.sub(r"~~(.*?)~~", r"\1", text)
+
+    # Remove list items (e.g., - item, * item, 1. item)
+    text = re.sub(r"^\s*([-*+]|[0-9]+\.)\s+", "", text, flags=re.MULTILINE)
+
+    # Remove hyperlinks (e.g., [link text](http://example.com))
+    text = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text)
+
+    # Remove inline code (e.g., `code`)
+    text = re.sub(r"`(.*?)`", r"\1", text)
+
+    # Remove blockquotes (e.g., > quote)
+    text = re.sub(r"^\s*>\s*(.*)", r"\1", text, flags=re.MULTILINE)
+
+    # Remove multiple newlines
+    text = re.sub(r"\n{3,}", "\n\n", text)
+
+    # Limit dots sequences to max 5 dots
+    text = re.sub(r"(?:\s*\.\s*){3,}", _fix_dots, text, flags=re.DOTALL)
+
+    return text
+
+
+def reverse_transform_bbox(
+    bbox: Tuple[int, int, int, int],
+    bbox_offset: Tuple[int, int],
+    original_width: int,
+    original_height: int,
+) -> Tuple[int, int, int, int]:
+    width_ratio = (original_width - 2 * bbox_offset[0]) / original_width
+    height_ratio = (original_height - 2 * bbox_offset[1]) / original_height
+    w1, h1, w2, h2 = bbox
+    w1 = int((w1 - bbox_offset[0]) / width_ratio)
+    h1 = int((h1 - bbox_offset[1]) / height_ratio)
+    w2 = int((w2 - bbox_offset[0]) / width_ratio)
+    h2 = int((h2 - bbox_offset[1]) / height_ratio)
+
+    return (w1, h1, w2, h2)
+
+
+def postprocess_text(txt: str, cls: str):
+    if cls in ACCEPTED_CLASSES:
+        txt = txt.replace("<tbc>", "").strip()  # remove <tbc> tokens (continued paragraphs)
+        txt = strip_markdown_formatting(txt)
+    else:
+        txt = ""
+
+    return txt
diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py
deleted file mode 100644
index 7b0c8bfd..00000000
--- a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_utils.py
+++ /dev/null
@@ -1,333 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
-# All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-import numpy as np
-import PIL
-import pytest
-
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import convert_mmd_to_plain_text_ours
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import extract_classes_bboxes
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import pad_image
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import reverse_transform_bbox
-from nv_ingest.extraction_workflows.pdf.doughnut_utils import postprocess_text
-
-
-def test_pad_image_same_size():
-    array = np.ones((100, 100, 3), dtype=np.uint8)
-    padded_array, (pad_width, pad_height) = pad_image(array, 100, 100)
-
-    assert np.array_equal(padded_array, array)
-    assert pad_width == 0
-    assert pad_height == 0
-
-
-def test_pad_image_smaller_size():
-    array = np.ones((50, 50, 3), dtype=np.uint8)
-    padded_array, (pad_width, pad_height) = pad_image(array, 100, 100)
-
-    assert padded_array.shape == (100, 100, 3)
-    assert pad_width == (100 - 50) // 2
-    assert pad_height == (100 - 50) // 2
-    assert np.array_equal(padded_array[pad_height : pad_height + 50, pad_width : pad_width + 50], array)  # noqa: E203
-
-
-def test_pad_image_width_exceeds_target():
-    array = np.ones((50, 150, 3), dtype=np.uint8)
-    with pytest.raises(ValueError, match="Image array is too large"):
-        pad_image(array, 100, 100)
-
-
-def test_pad_image_height_exceeds_target():
-    array = np.ones((150, 50, 3), dtype=np.uint8)
-    with pytest.raises(ValueError, match="Image array is too large"):
-        pad_image(array, 100, 100)
-
-
-def test_pad_image_with_non_default_target():
-    array = np.ones((60, 60, 3), dtype=np.uint8)
-    target_width = 80
-    target_height = 80
-    padded_array, (pad_width, pad_height) = pad_image(array, target_width, target_height)
-
-    assert padded_array.shape == (target_height, target_width, 3)
-    assert pad_width == (target_width - 60) // 2
-    assert pad_height == (target_height - 60) // 2
-    assert np.array_equal(padded_array[pad_height : pad_height + 60, pad_width : pad_width + 60], array)  # noqa: E203
-
-
-def test_extract_classes_bboxes_simple():
-    text = "<x_10><y_20>text1<x_30><y_40><class_A>"
-    expected_classes = ["A"]
-    expected_bboxes = [(10, 20, 30, 40)]
-    expected_texts = ["text1"]
-    classes, bboxes, texts = extract_classes_bboxes(text)
-    assert classes == expected_classes
-    assert bboxes == expected_bboxes
-    assert texts == expected_texts
-
-
-def test_extract_classes_bboxes_multiple():
-    text = "<x_10><y_20>text1<x_30><y_40><class_A>\n<x_50><y_60>text2<x_70><y_80><class_B>"
-    expected_classes = ["A", "B"]
-    expected_bboxes = [(10, 20, 30, 40), (50, 60, 70, 80)]
-    expected_texts = ["text1", "text2"]
-    classes, bboxes, texts = extract_classes_bboxes(text)
-    assert classes == expected_classes
-    assert bboxes == expected_bboxes
-    assert texts == expected_texts
-
-
-def test_extract_classes_bboxes_no_match():
-    text = "This text does not match the pattern"
-    expected_classes = []
-    expected_bboxes = []
-    expected_texts = []
-    classes, bboxes, texts = extract_classes_bboxes(text)
-    assert classes == expected_classes
-    assert bboxes == expected_bboxes
-    assert texts == expected_texts
-
-
-def test_extract_classes_bboxes_different_format():
-    text = "<x_1><y_2>sample<x_3><y_4><class_test>"
-    expected_classes = ["test"]
-    expected_bboxes = [(1, 2, 3, 4)]
-    expected_texts = ["sample"]
-    classes, bboxes, texts = extract_classes_bboxes(text)
-    assert classes == expected_classes
-    assert bboxes == expected_bboxes
-    assert texts == expected_texts
-
-
-def test_extract_classes_bboxes_empty_input():
-    text = ""
-    expected_classes = []
-    expected_bboxes = []
-    expected_texts = []
-    classes, bboxes, texts = extract_classes_bboxes(text)
-    assert classes == expected_classes
-    assert bboxes == expected_bboxes
-    assert texts == expected_texts
-
-
-def test_convert_mmd_to_plain_text_ours_headers():
-    mmd_text = "## Header"
-    expected = "Header"
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def test_convert_mmd_to_plain_text_ours_bold_italic():
-    mmd_text = "This is **bold** and *italic* text."
-    expected = "This is bold and italic text."
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def test_convert_mmd_to_plain_text_ours_inline_math():
-    mmd_text = "This is a formula: \\(E=mc^2\\)"
-    expected_with_math = "This is a formula: E=mc^2"
-    expected_without_math = "This is a formula:"
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected_with_math
-    assert convert_mmd_to_plain_text_ours(mmd_text, remove_inline_math=True) == expected_without_math
-
-
-def test_convert_mmd_to_plain_text_ours_lists_tables():
-    mmd_text = "* List item\n\\begin{table}content\\end{table}"
-    expected = "List item"
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def test_convert_mmd_to_plain_text_ours_code_blocks_equations():
-    mmd_text = "```\ncode block\n```\n\\[ equation \\]"
-    expected = ""
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def test_convert_mmd_to_plain_text_ours_special_chars():
-    mmd_text = "Backslash \\ should be removed."
-    expected = "Backslash should be removed."
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def test_convert_mmd_to_plain_text_ours_mixed_content():
-    mmd_text = """
-    ## Header
-
-    This is **bold** and *italic* text with a formula: \\(E=mc^2\\).
-
-    \\\[ equation \\\]
-
-    \\begin{table}content\\end{table}
-    """  # noqa: W605
-    expected = """
-    Header
-
-    This is bold and italic text with a formula: E=mc^2.
-
-    """.strip()  # noqa: W605
-    assert convert_mmd_to_plain_text_ours(mmd_text) == expected
-
-
-def create_test_image(width: int, height: int, color: tuple = (255, 0, 0)) -> PIL.Image:
-    """Helper function to create a solid color image for testing."""
-    image = PIL.Image.new("RGB", (width, height), color)
-    return image
-
-
-# def test_pymupdf_page_to_numpy_array_simple():
-#     mock_page = Mock(spec=fitz.Page)
-#     mock_pixmap = Mock()
-#     mock_pixmap.pil_tobytes.return_value = b"fake image data"
-#     mock_page.get_pixmap.return_value = mock_pixmap
-#
-#     with patch("PIL.Image.open", return_value=PIL.Image.new("RGB", (100, 100))):
-#         image, offset = pymupdf_page_to_numpy_array(mock_page, target_width=100, target_height=100)
-#
-#     assert isinstance(image, np.ndarray)
-#     assert image.shape == (1280, 1024, 3)
-#     assert isinstance(offset, tuple)
-#     assert len(offset) == 2
-#     assert all(isinstance(x, int) for x in offset)
-#     mock_page.get_pixmap.assert_called_once_with(dpi=300)
-#     mock_pixmap.pil_tobytes.assert_called_once_with(format="PNG")
-
-
-# def test_pymupdf_page_to_numpy_array_different_dpi():
-#     mock_page = Mock(spec=fitz.Page)
-#     mock_pixmap = Mock()
-#     mock_pixmap.pil_tobytes.return_value = b"fake image data"
-#     mock_page.get_pixmap.return_value = mock_pixmap
-
-
-def test_crop_image_valid_bbox():
-    array = np.ones((100, 100, 3), dtype=np.uint8) * 255
-    bbox = (10, 10, 50, 50)
-    result = crop_image(array, bbox)
-    assert result is not None
-    assert isinstance(result, str)
-
-
-def test_crop_image_partial_outside_bbox():
-    array = np.ones((100, 100, 3), dtype=np.uint8) * 255
-    bbox = (90, 90, 110, 110)
-    result = crop_image(array, bbox)
-    assert result is not None
-    assert isinstance(result, str)
-
-
-def test_crop_image_completely_outside_bbox():
-    array = np.ones((100, 100, 3), dtype=np.uint8) * 255
-    bbox = (110, 110, 120, 120)
-    result = crop_image(array, bbox)
-    assert result is None
-
-
-def test_crop_image_zero_area_bbox():
-    array = np.ones((100, 100, 3), dtype=np.uint8) * 255
-    bbox = (50, 50, 50, 50)
-    result = crop_image(array, bbox)
-    assert result is None
-
-
-def test_crop_image_different_format():
-    array = np.ones((100, 100, 3), dtype=np.uint8) * 255
-    bbox = (10, 10, 50, 50)
-    result = crop_image(array, bbox, format="JPEG")
-    assert result is not None
-    assert isinstance(result, str)
-
-
-def test_reverse_transform_bbox_no_offset():
-    bbox = (10, 20, 30, 40)
-    bbox_offset = (0, 0)
-    expected_bbox = (10, 20, 30, 40)
-    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
-
-    assert transformed_bbox == expected_bbox
-
-
-def test_reverse_transform_bbox_with_offset():
-    bbox = (20, 30, 40, 50)
-    bbox_offset = (10, 10)
-    expected_bbox = (12, 25, 37, 50)
-    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
-
-    assert transformed_bbox == expected_bbox
-
-
-def test_reverse_transform_bbox_with_large_offset():
-    bbox = (60, 80, 90, 100)
-    bbox_offset = (20, 30)
-    width_ratio = (100 - 2 * bbox_offset[0]) / 100
-    height_ratio = (100 - 2 * bbox_offset[1]) / 100
-    expected_bbox = (
-        int((60 - bbox_offset[0]) / width_ratio),
-        int((80 - bbox_offset[1]) / height_ratio),
-        int((90 - bbox_offset[0]) / width_ratio),
-        int((100 - bbox_offset[1]) / height_ratio),
-    )
-    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
-
-    assert transformed_bbox == expected_bbox
-
-
-def test_reverse_transform_bbox_custom_dimensions():
-    bbox = (15, 25, 35, 45)
-    bbox_offset = (5, 5)
-    original_width = 200
-    original_height = 200
-    width_ratio = (original_width - 2 * bbox_offset[0]) / original_width
-    height_ratio = (original_height - 2 * bbox_offset[1]) / original_height
-    expected_bbox = (
-        int((15 - bbox_offset[0]) / width_ratio),
-        int((25 - bbox_offset[1]) / height_ratio),
-        int((35 - bbox_offset[0]) / width_ratio),
-        int((45 - bbox_offset[1]) / height_ratio),
-    )
-    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, original_width, original_height)
-
-    assert transformed_bbox == expected_bbox
-
-
-def test_reverse_transform_bbox_zero_dimension():
-    bbox = (10, 10, 20, 20)
-    bbox_offset = (0, 0)
-    original_width = 0
-    original_height = 0
-    with pytest.raises(ZeroDivisionError):
-        reverse_transform_bbox(bbox, bbox_offset, original_width, original_height)
-
-
-def test_postprocess_text_with_unaccepted_class():
-    # Input text that should not be processed
-    txt = "This text should not be processed"
-    cls = "InvalidClass"  # Not in ACCEPTED_CLASSES
-
-    result = postprocess_text(txt, cls)
-
-    assert result == ""
-
-
-def test_postprocess_text_removes_tbc_and_processes_text():
-    # Input text containing "<tbc>"
-    txt = "<tbc>Some text"
-    cls = "Title"  # An accepted class
-
-    expected_output = "Some text"
-
-    result = postprocess_text(txt, cls)
-
-    assert result == expected_output
-
-
-def test_postprocess_text_no_tbc_but_accepted_class():
-    # Input text without "<tbc>"
-    txt = "This is a test **without** tbc"
-    cls = "Section-header"  # An accepted class
-
-    expected_output = "This is a test without tbc"
-
-    result = postprocess_text(txt, cls)
-
-    assert result == expected_output
diff --git a/tests/nv_ingest/util/nim/test_doughnut.py b/tests/nv_ingest/util/nim/test_doughnut.py
new file mode 100644
index 00000000..656acf8f
--- /dev/null
+++ b/tests/nv_ingest/util/nim/test_doughnut.py
@@ -0,0 +1,143 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from nv_ingest.util.nim.doughnut import extract_classes_bboxes
+from nv_ingest.util.nim.doughnut import postprocess_text
+from nv_ingest.util.nim.doughnut import reverse_transform_bbox
+from nv_ingest.util.nim.doughnut import strip_markdown_formatting
+
+
+def test_reverse_transform_bbox_no_offset():
+    bbox = (10, 20, 30, 40)
+    bbox_offset = (0, 0)
+    expected_bbox = (10, 20, 30, 40)
+    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
+
+    assert transformed_bbox == expected_bbox
+
+
+def test_reverse_transform_bbox_with_offset():
+    bbox = (20, 30, 40, 50)
+    bbox_offset = (10, 10)
+    expected_bbox = (12, 25, 37, 50)
+    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
+
+    assert transformed_bbox == expected_bbox
+
+
+def test_reverse_transform_bbox_with_large_offset():
+    bbox = (60, 80, 90, 100)
+    bbox_offset = (20, 30)
+    width_ratio = (100 - 2 * bbox_offset[0]) / 100
+    height_ratio = (100 - 2 * bbox_offset[1]) / 100
+    expected_bbox = (
+        int((60 - bbox_offset[0]) / width_ratio),
+        int((80 - bbox_offset[1]) / height_ratio),
+        int((90 - bbox_offset[0]) / width_ratio),
+        int((100 - bbox_offset[1]) / height_ratio),
+    )
+    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, 100, 100)
+
+    assert transformed_bbox == expected_bbox
+
+
+def test_reverse_transform_bbox_custom_dimensions():
+    bbox = (15, 25, 35, 45)
+    bbox_offset = (5, 5)
+    original_width = 200
+    original_height = 200
+    width_ratio = (original_width - 2 * bbox_offset[0]) / original_width
+    height_ratio = (original_height - 2 * bbox_offset[1]) / original_height
+    expected_bbox = (
+        int((15 - bbox_offset[0]) / width_ratio),
+        int((25 - bbox_offset[1]) / height_ratio),
+        int((35 - bbox_offset[0]) / width_ratio),
+        int((45 - bbox_offset[1]) / height_ratio),
+    )
+    transformed_bbox = reverse_transform_bbox(bbox, bbox_offset, original_width, original_height)
+
+    assert transformed_bbox == expected_bbox
+
+
+def test_reverse_transform_bbox_zero_dimension():
+    bbox = (10, 10, 20, 20)
+    bbox_offset = (0, 0)
+    original_width = 0
+    original_height = 0
+    with pytest.raises(ZeroDivisionError):
+        reverse_transform_bbox(bbox, bbox_offset, original_width, original_height)
+
+
+def test_postprocess_text_with_unaccepted_class():
+    # Input text that should not be processed
+    txt = "This text should not be processed"
+    cls = "InvalidClass"  # Not in ACCEPTED_CLASSES
+
+    result = postprocess_text(txt, cls)
+
+    assert result == ""
+
+
+def test_postprocess_text_removes_tbc_and_processes_text():
+    # Input text containing "<tbc>"
+    txt = "<tbc>Some text"
+    cls = "Title"  # An accepted class
+
+    expected_output = "Some text"
+
+    result = postprocess_text(txt, cls)
+
+    assert result == expected_output
+
+
+def test_postprocess_text_no_tbc_but_accepted_class():
+    # Input text without "<tbc>"
+    txt = "This is a test **without** tbc"
+    cls = "Section-header"  # An accepted class
+
+    expected_output = "This is a test without tbc"
+
+    result = postprocess_text(txt, cls)
+
+    assert result == expected_output
+
+
+@pytest.mark.parametrize(
+    "input_text, expected_classes, expected_bboxes, expected_texts",
+    [
+        ("<x_10><y_20>Sample text<x_30><y_40><class_Text>", ["Text"], [(10, 20, 30, 40)], ["Sample text"]),
+        (
+            "Invalid text <x_10><y_20><x_30><y_40><class_Bad>",
+            ["Bad-box", "Bad-box"],
+            [(0, 0, 0, 0), (10, 20, 30, 40)],
+            ["Invalid text", ""],
+        ),
+        ("<x_15><y_25>Header content<x_35><y_45><class_Title>", ["Title"], [(15, 25, 35, 45)], ["Header content"]),
+        ("<x_5><y_10>Overlapping box<x_5><y_10><class_Text>", ["Bad-box"], [(5, 10, 5, 10)], ["Overlapping box"]),
+    ],
+)
+def test_extract_classes_bboxes(input_text, expected_classes, expected_bboxes, expected_texts):
+    classes, bboxes, texts = extract_classes_bboxes(input_text)
+    assert classes == expected_classes
+    assert bboxes == expected_bboxes
+    assert texts == expected_texts
+
+
+# Test cases for strip_markdown_formatting
+@pytest.mark.parametrize(
+    "input_text, expected_output",
+    [
+        ("# Header\n**Bold text**\n*Italic*", "Header\nBold text\nItalic"),
+        ("~~Strikethrough~~", "Strikethrough"),
+        ("[Link](http://example.com)", "Link"),
+        ("`inline code`", "inline code"),
+        ("> Blockquote", "Blockquote"),
+        ("Normal text\n\n\nMultiple newlines", "Normal text\n\nMultiple newlines"),
+        ("Dot sequence...... more text", "Dot sequence..... more text"),
+    ],
+)
+def test_strip_markdown_formatting(input_text, expected_output):
+    assert strip_markdown_formatting(input_text) == expected_output
diff --git a/tests/nv_ingest_client/client/test_interface.py b/tests/nv_ingest_client/client/test_interface.py
index f8b4ff8e..5f7e4428 100644
--- a/tests/nv_ingest_client/client/test_interface.py
+++ b/tests/nv_ingest_client/client/test_interface.py
@@ -10,16 +10,18 @@
 from unittest.mock import patch

 import pytest
-from nv_ingest_client.client import NvIngestClient
 from nv_ingest_client.client import Ingestor
+from nv_ingest_client.client import NvIngestClient
 from nv_ingest_client.primitives import BatchJobSpec
 from nv_ingest_client.primitives.jobs import JobStateEnum
+from nv_ingest_client.primitives.tasks import ChartExtractionTask
 from nv_ingest_client.primitives.tasks import DedupTask
 from nv_ingest_client.primitives.tasks import EmbedTask
 from nv_ingest_client.primitives.tasks import ExtractTask
 from nv_ingest_client.primitives.tasks import FilterTask
 from nv_ingest_client.primitives.tasks import SplitTask
 from nv_ingest_client.primitives.tasks import StoreTask
+from nv_ingest_client.primitives.tasks import TableExtractionTask
 from nv_ingest_client.primitives.tasks import VdbUploadTask

 MODULE_UNDER_TEST = "nv_ingest_client.client.interface"
@@ -80,7 +82,42 @@ def test_embed_task_some_args(ingestor):
 def test_extract_task_no_args(ingestor):
     ingestor.extract()

-    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], ExtractTask)
+    task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
+    assert isinstance(task, ExtractTask)
+    assert task._extract_tables is True
+    assert task._extract_charts is True
+
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[1], TableExtractionTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[2], ChartExtractionTask)
+
+
+def test_extract_task_args_tables_false(ingestor):
+    ingestor.extract(extract_tables=False)
+
+    task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
+    assert isinstance(task, ExtractTask)
+    assert task._extract_tables is False
+    assert task._extract_charts is True
+
+
+def test_extract_task_args_charts_false(ingestor):
+    ingestor.extract(extract_charts=False)
+
+    task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
+    assert isinstance(task, ExtractTask)
+    assert task._extract_tables is True
+    assert task._extract_charts is False
+
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[1], TableExtractionTask)
+
+
+def test_extract_task_args_tables_and_charts_false(ingestor):
+    ingestor.extract(extract_tables=False, extract_charts=False)
+
+    task = ingestor._job_specs.job_specs["pdf"][0]._tasks[0]
+    assert isinstance(task, ExtractTask)
+    assert task._extract_tables is False
+    assert task._extract_charts is False


 def test_extract_task_some_args(ingestor):
@@ -156,11 +193,13 @@ def test_chain(ingestor):
     assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[0], DedupTask)
     assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[1], EmbedTask)
     assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[2], ExtractTask)
-    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[3], FilterTask)
-    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[4], SplitTask)
-    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[5], StoreTask)
-    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[6], VdbUploadTask)
-    assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 7
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[3], TableExtractionTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[4], ChartExtractionTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[5], FilterTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[6], SplitTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[7], StoreTask)
+    assert isinstance(ingestor._job_specs.job_specs["pdf"][0]._tasks[8], VdbUploadTask)
+    assert len(ingestor._job_specs.job_specs["pdf"][0]._tasks) == 9


 def test_ingest(ingestor, mock_client):
@@ -190,8 +229,8 @@ def test_ingest_async(ingestor, mock_client):
     ingestor._job_states["job_id_1"] = MagicMock(state=JobStateEnum.COMPLETED)
     ingestor._job_states["job_id_2"] = MagicMock(state=JobStateEnum.FAILED)

-    mock_client.fetch_job_result.side_effect = (
-        lambda job_id, *args, **kwargs: "result_1" if job_id == "job_id_1" else "result_2"
+    mock_client.fetch_job_result.side_effect = lambda job_id, *args, **kwargs: (
+        "result_1" if job_id == "job_id_1" else "result_2"
    )

     combined_future = ingestor.ingest_async(timeout=15)
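
Usage sketch (illustrative only, not part of the patch): the interface.py change above flips the extract_tables / extract_charts defaults to True and appends TableExtractionTask / ChartExtractionTask automatically. A minimal client-side example, assuming the nv_ingest_client package is installed and a local "sample.pdf" exists:

    from nv_ingest_client.client import Ingestor

    # With the new defaults, extract() schedules the base ExtractTask plus
    # table and chart extraction tasks for each document type.
    ingestor = Ingestor(documents=["sample.pdf"]).extract()

    # Passing False for both restores the previous behavior of a bare ExtractTask.
    text_only = Ingestor(documents=["sample.pdf"]).extract(
        extract_tables=False, extract_charts=False
    )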
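
Worked example for the util.py change (hypothetical payload; the field names mirror the code above, the values are invented): check_ingest_result() now prefixes the failure description with the source name and appends, rather than overwrites, the failing annotation's message:

    from nv_ingest_client.util.util import check_ingest_result

    payload = {
        "status": "failed",
        "description": "Job failed at embedding stage",
        "data": [{"metadata": {"source_metadata": {"source_name": "report.pdf"}}}],
        "annotations": {"task::embed": {"task_result": "FAILURE", "message": "OOM"}},
    }
    is_failed, description = check_ingest_result(payload)
    # is_failed -> True
    # description ->
    # "[report.pdf]: failed\nJob failed at embedding stage\n↪ Event that caused this failure: task::embed -> OOM"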