Trace image model inference time in pdf extraction (#109)
edknv authored Oct 3, 2024
1 parent fb8147a commit 06c797a
Showing 10 changed files with 442 additions and 87 deletions.
23 changes: 13 additions & 10 deletions src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -19,6 +19,7 @@
import logging
from math import log
from typing import List
+from typing import Optional
from typing import Tuple

import numpy as np
@@ -70,6 +71,7 @@ def extract_tables_and_charts_using_image_ensemble(
iou_thresh: float = YOLOX_IOU_THRESHOLD,
min_score: float = YOLOX_MIN_SCORE,
final_thresh: float = YOLOX_FINAL_SCORE,
+trace_info: Optional[List] = None,
) -> List[Tuple[int, ImageTable]]:
"""
Extract tables and charts from a series of document pages using an ensemble of image-based models.
@@ -146,12 +148,13 @@ def extract_tables_and_charts_using_image_ensemble(

page_idx = 0
for batch in batches:
-original_images, _ = pdfium_pages_to_numpy(batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT))
+original_images, _ = pdfium_pages_to_numpy(batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT), trace_info=trace_info)

original_image_shapes = [image.shape for image in original_images]
input_array = prepare_images_for_inference(original_images)

-output_array = perform_model_inference(yolox_client, "yolox", input_array)
+output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info)

results = process_inference_results(
output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh
)
@@ -165,6 +168,7 @@
deplot_client,
cached_client,
tables_and_charts,
+trace_info=trace_info,
)

page_idx += 1
@@ -288,7 +292,7 @@ def process_inference_results(

# Handle individual table/chart extraction and model inference
def handle_table_chart_extraction(
-annotation_dict, original_image, page_idx, paddle_client, deplot_client, cached_client, tables_and_charts
+annotation_dict, original_image, page_idx, paddle_client, deplot_client, cached_client, tables_and_charts, trace_info=None,
):
"""
Handle the extraction of tables and charts from the inference results and run additional model inference.
@@ -345,23 +349,23 @@ def handle_table_chart_extraction(
)
base64_img = numpy_to_base64(cropped)

-table_content = call_image_inference_model(paddle_client, "paddle", cropped)
+table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info)
table_data = ImageTable(table_content, base64_img, (w1, h1, w2, h2))
tables_and_charts.append((page_idx, table_data))
elif label == "chart":
cropped = crop_image(original_image, (h1, w1, h2, w2))
base64_img = numpy_to_base64(cropped)

-deplot_result = call_image_inference_model(deplot_client, "google/deplot", cropped)
-cached_result = call_image_inference_model(cached_client, "cached", cropped)
+deplot_result = call_image_inference_model(deplot_client, "google/deplot", cropped, trace_info=trace_info)
+cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info)
chart_content = join_cached_and_deplot_output(cached_result, deplot_result)
chart_data = ImageChart(chart_content, base64_img, (w1, h1, w2, h2))
tables_and_charts.append((page_idx, chart_data))


# Define a helper function that uses pdfium to extract text from a base64-encoded
# bytestream PDF.
-def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables: bool, **kwargs):
+def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables: bool, trace_info=None, **kwargs):
"""
Helper function to use pdfium to extract text from a bytestream PDF.
@@ -385,7 +389,6 @@ def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables:
str
A string of extracted text.
"""

logger.debug("Extracting PDF with pdfium backend.")

row_data = kwargs.get("row_data")
@@ -506,10 +509,10 @@ def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables:
extracted_data.append(text_extraction)

if extract_tables:
-for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble(pages, pdfium_config):
+for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble(pages, pdfium_config, trace_info=trace_info):
extracted_data.append(
construct_table_and_chart_metadata(
-table_and_charts, page_idx, pdf_metadata.page_count, source_metadata, base_unified_metadata
+table_and_charts, page_idx, pdf_metadata.page_count, source_metadata, base_unified_metadata,
)
)

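Note: the diff threads trace_info through every image-model call, but the code that actually records the timings lives in the client helpers and is not shown here. As a rough sketch of the idea only: a wrapper that stamps entry/exit times around an inference call, using the trace::entry::/trace::exit:: key convention from the telemetry module below. The name timed_inference is hypothetical, and while the diff annotates trace_info as Optional[List], this sketch uses a plain dict keyed by trace name for simplicity.

from datetime import datetime


def timed_inference(fn, name, *args, trace_info=None, **kwargs):
    # Stamp an entry time, run the model call, and stamp an exit time
    # so downstream telemetry can compute per-model latency.
    ts_entry = datetime.now()
    try:
        return fn(*args, **kwargs)
    finally:
        if trace_info is not None:
            trace_info[f"trace::entry::{name}"] = ts_entry
            trace_info[f"trace::exit::{name}"] = datetime.now()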
120 changes: 83 additions & 37 deletions src/nv_ingest/modules/telemetry/otel_tracer.py
@@ -71,29 +71,7 @@ def collect_timestamps(message):
trace_id = int(trace_id, 16)
span_id = RandomIdGenerator().generate_span_id()

-timestamps = {}
-for key, val in message.filter_timestamp("trace::exit::").items():
-    exit_key = key
-    entry_key = exit_key.replace("trace::exit::", "trace::entry::")
-    ts_entry = message.get_timestamp(entry_key)
-    ts_exit = message.get_timestamp(exit_key)
-    job_name = key.replace("trace::exit::", "")
-
-    ts_entry_ns = int(ts_entry.timestamp() * 1e9)
-    ts_exit_ns = int(ts_exit.timestamp() * 1e9)
-
-    timestamps[job_name] = (ts_entry_ns, ts_exit_ns)
-
-task_results = {}
-for key in message.list_metadata():
-    if not key.startswith("annotation::"):
-        continue
-    task = message.get_metadata(key)
-    if not (("task_id" in task) and ("task_result" in task)):
-        continue
-    task_id = task["task_id"]
-    task_result = task["task_result"]
-    task_results[task_id] = task_result
+timestamps = extract_timestamps_from_message(message)

flattened = [x for t in timestamps.values() for x in t]
start_time = min(flattened)
@@ -107,20 +85,8 @@
)
parent_ctx = trace.set_span_in_context(NonRecordingSpan(span_context))
parent_span = tracer.start_span(job_id, context=parent_ctx, start_time=start_time)
-child_ctx = trace.set_span_in_context(parent_span)
-for job_name, (ts_entry, ts_exit) in timestamps.items():
-    span = tracer.start_span(job_name, context=child_ctx, start_time=ts_entry)
-    if job_name in task_results:
-        task_result = task_results[job_name]
-        if task_result == TaskResultStatus.SUCCESS.value:
-            span.set_status(Status(StatusCode.OK))
-        if task_result == TaskResultStatus.FAILURE.value:
-            span.set_status(Status(StatusCode.ERROR))
-    try:
-        span.add_event("entry", timestamp=ts_entry)
-        span.add_event("exit", timestamp=ts_exit)
-    finally:
-        span.end(end_time=ts_exit)

+create_span_with_timestamps(tracer, parent_span, message)

if message.has_metadata("cm_failed") and message.get_metadata("cm_failed"):
parent_span.set_status(Status(StatusCode.ERROR))
@@ -157,3 +123,83 @@ def on_next(message: ControlMessage) -> ControlMessage:

builder.register_module_input("input", aggregate_node)
builder.register_module_output("output", aggregate_node)


+def extract_timestamps_from_message(message):
+    timestamps = {}
+    dedup_counter = {}
+
+    for key, val in message.filter_timestamp("trace::exit::").items():
+        exit_key = key
+        entry_key = exit_key.replace("trace::exit::", "trace::entry::")
+
+        task_name = key.replace("trace::exit::", "")
+        if task_name in dedup_counter:
+            dedup_counter[task_name] += 1
+            task_name = task_name + "_" + str(dedup_counter[task_name])
+        else:
+            dedup_counter[task_name] = 0
+
+        ts_entry = message.get_timestamp(entry_key)
+        ts_exit = message.get_timestamp(exit_key)
+        ts_entry_ns = int(ts_entry.timestamp() * 1e9)
+        ts_exit_ns = int(ts_exit.timestamp() * 1e9)
+
+        timestamps[task_name] = (ts_entry_ns, ts_exit_ns)
+
+    return timestamps
+
+
+def extract_annotated_task_results(message):
+    task_results = {}
+    for key in message.list_metadata():
+        if not key.startswith("annotation::"):
+            continue
+        task = message.get_metadata(key)
+        if not (("task_id" in task) and ("task_result" in task)):
+            continue
+        task_id = task["task_id"]
+        task_result = task["task_result"]
+        task_results[task_id] = task_result
+
+    return task_results
+
+
+def create_span_with_timestamps(tracer, parent_span, message):
+    timestamps = extract_timestamps_from_message(message)
+    task_results = extract_annotated_task_results(message)
+
+    ctx_store = {}
+    child_ctx = trace.set_span_in_context(parent_span)
+    for task_name, (ts_entry, ts_exit) in sorted(timestamps.items(), key=lambda x: x[1]):
+        main_task, *subtask = task_name.split("::", 1)
+        subtask = "::".join(subtask)
+
+        if not subtask:
+            span = tracer.start_span(main_task, context=child_ctx, start_time=ts_entry)
+        else:
+            subtask_ctx = trace.set_span_in_context(ctx_store[main_task][0])
+            span = tracer.start_span(subtask, context=subtask_ctx, start_time=ts_entry)
+
+        # Set success/failure status.
+        if task_name in task_results:
+            task_result = task_results[task_name]
+            if task_result == TaskResultStatus.SUCCESS.value:
+                span.set_status(Status(StatusCode.OK))
+            if task_result == TaskResultStatus.FAILURE.value:
+                span.set_status(Status(StatusCode.ERROR))
+
+        # Add timestamps.
+        span.add_event("entry", timestamp=ts_entry)
+        span.add_event("exit", timestamp=ts_exit)
+
+        # Cache span and exit time.
+        # Spans are used for looking up the main task's span when creating a subtask's span.
+        # Exit timestamps are used for closing each span at the very end.
+        ctx_store[task_name] = (span, ts_exit)
+
+    for _, (span, ts_exit) in ctx_store.items():
+        span.end(end_time=ts_exit)
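Two behaviors in the new helpers are worth calling out: repeated task names get numeric suffixes so every invocation keeps its own span, and a name containing "::" is attached as a child span of its main task via ctx_store. A standalone sketch of just the de-dup naming, using hypothetical task names that are not part of this commit:

def dedup_task_names(exit_keys):
    # Mirrors the counter logic in extract_timestamps_from_message.
    dedup_counter = {}
    names = []
    for key in exit_keys:
        task_name = key.replace("trace::exit::", "")
        if task_name in dedup_counter:
            dedup_counter[task_name] += 1
            task_name = f"{task_name}_{dedup_counter[task_name]}"
        else:
            dedup_counter[task_name] = 0
        names.append(task_name)
    return names


print(dedup_task_names([
    "trace::exit::pdf_extractor",
    "trace::exit::pdf_extractor",
    "trace::exit::pdf_extractor::yolox",
]))
# ['pdf_extractor', 'pdf_extractor_1', 'pdf_extractor::yolox']

In the real helper, "pdf_extractor::yolox" would then be created as a child span of the "pdf_extractor" span looked up in ctx_store.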
62 changes: 38 additions & 24 deletions src/nv_ingest/stages/multiprocessing_stage.py
@@ -44,7 +44,7 @@ def trace_message(ctrl_msg, task_desc):
"""
ts_fetched = datetime.now()
do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
ctrl_msg.get_metadata("config::add_trace_tagging") is True
)

if do_trace_tagging:
@@ -149,14 +149,14 @@ class MultiProcessingBaseStage(SinglePortStage):
"""

def __init__(
self,
c: Config,
task: str,
task_desc: str,
pe_count: int,
process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
document_type: str = None,
filter_properties: dict = None,
):
super().__init__(c)
self._document_type = document_type
@@ -199,11 +199,11 @@ def supports_cpp_node(self) -> bool:

@staticmethod
def work_package_input_handler(
work_package_input_queue: mp.Queue,
work_package_response_queue: mp.Queue,
cancellation_token: mp.Value,
process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
process_pool: ProcessWorkerPoolSingleton,
):
"""
Processes work packages received from the recv_queue, applies the process_fn to each package,
@@ -243,9 +243,13 @@ def work_package_input_handler(
future = process_pool.submit_task(process_fn, (df, task_props))

# This can return/raise an exception
-result = future.result()
+result, *extra_results = future.result()

work_package["payload"] = result
+if extra_results:
+    for extra_result in extra_results:
+        if isinstance(extra_result, dict) and ("trace_info" in extra_result):
+            work_package["trace_info"] = extra_result["trace_info"]

work_package_response_queue.put({"type": "on_next", "value": work_package})
except Exception as e:
@@ -275,13 +279,13 @@ def work_package_input_handler(

@staticmethod
def work_package_response_handler(
mp_context,
max_queue_size,
work_package_input_queue: mp.Queue,
sub: mrc.Subscriber,
cancellation_token: mp.Value,
process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
process_pool: ProcessWorkerPoolSingleton,
):
"""
Manages child threads and collects results, forwarding them to the subscriber.
@@ -478,11 +482,21 @@ def reconstruct_fn(work_package):
def cm_func(ctrl_msg: ControlMessage, work_package: dict):
# This is the first location where we have access to both the control message and the work package,
# if we had any errors in the processing, raise them here.
-if (work_package.get("error", False)):
+if work_package.get("error", False):
raise RuntimeError(work_package["error_message"])

gdf = cudf.from_pandas(work_package["payload"])
ctrl_msg.payload(MessageMeta(df=gdf))

+do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
+    ctrl_msg.get_metadata("config::add_trace_tagging") is True
+)
+if do_trace_tagging:
+    trace_info = work_package.get("trace_info")
+    if trace_info:
+        for key, ts in trace_info.items():
+            ctrl_msg.set_timestamp(key, ts)

return ctrl_msg

return cm_func(ctrl_msg, work_package)
@@ -523,7 +537,7 @@ def merge_fn(ctrl_msg: ControlMessage):
The control message with updated tracing metadata.
"""
do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
ctrl_msg.get_metadata("config::add_trace_tagging") is True
)

if do_trace_tagging:
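With this change, process_fn is expected to return a tuple whose first element is the payload DataFrame, optionally followed by a dict carrying "trace_info"; cm_func then copies those timestamps onto the control message when trace tagging is enabled. A minimal sketch of a conforming process_fn, with an illustrative function name and trace keys that are not part of this commit:

from datetime import datetime

import pandas as pd


def example_process_fn(df: pd.DataFrame, task_props: dict):
    ts_entry = datetime.now()
    result = df  # The actual extraction work would happen here.
    ts_exit = datetime.now()
    # The extra dict rides back to the parent stage in the work package
    # and ends up as ctrl_msg.set_timestamp(key, ts) entries.
    trace_info = {
        "trace::entry::pdf_extractor": ts_entry,
        "trace::exit::pdf_extractor": ts_exit,
    }
    return result, {"trace_info": trace_info}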