From 06c797aa0a0ccdf2a7dc5e4db465f2486a584d8e Mon Sep 17 00:00:00 2001
From: Edward Kim <109497216+edknv@users.noreply.github.com>
Date: Thu, 3 Oct 2024 06:50:20 -0700
Subject: [PATCH] Trace image model inference time in pdf extraction (#109)

---
 .../extraction_workflows/pdf/pdfium_helper.py |  23 ++--
 .../modules/telemetry/otel_tracer.py          | 120 ++++++++++++------
 src/nv_ingest/stages/multiprocessing_stage.py |  62 +++++----
 src/nv_ingest/stages/pdf_extractor_stage.py   |  34 +++--
 .../multi_processing/mp_pool_singleton.py     |   7 +-
 src/nv_ingest/util/nim/helpers.py             |   4 +
 src/nv_ingest/util/pdf/pdfium.py              |   2 +
 src/nv_ingest/util/tracing/tagging.py         | 103 +++++++++++++++
 .../modules/telemetry/test_otel_tracer.py     |  69 ++++++++++
 tests/nv_ingest/util/tracing/test_tagging.py  | 105 +++++++++++++++
 10 files changed, 442 insertions(+), 87 deletions(-)
 create mode 100644 tests/nv_ingest/modules/telemetry/test_otel_tracer.py

diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 9b43cf69..6474803d 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -19,6 +19,7 @@
 import logging
 from math import log
 from typing import List
+from typing import Optional
 from typing import Tuple
 
 import numpy as np
@@ -70,6 +71,7 @@ def extract_tables_and_charts_using_image_ensemble(
     iou_thresh: float = YOLOX_IOU_THRESHOLD,
     min_score: float = YOLOX_MIN_SCORE,
     final_thresh: float = YOLOX_FINAL_SCORE,
+    trace_info: Optional[List] = None,
 ) -> List[Tuple[int, ImageTable]]:
     """
     Extract tables and charts from a series of document pages using an ensemble of image-based models.
@@ -146,12 +148,13 @@ def extract_tables_and_charts_using_image_ensemble(
 
         page_idx = 0
         for batch in batches:
-            original_images, _ = pdfium_pages_to_numpy(batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT))
+            original_images, _ = pdfium_pages_to_numpy(
+                batch, scale_tuple=(YOLOX_MAX_WIDTH, YOLOX_MAX_HEIGHT), trace_info=trace_info
+            )
 
             original_image_shapes = [image.shape for image in original_images]
             input_array = prepare_images_for_inference(original_images)
 
-            output_array = perform_model_inference(yolox_client, "yolox", input_array)
+            output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info)
+
             results = process_inference_results(
                 output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh
             )
@@ -165,6 +168,7 @@ def extract_tables_and_charts_using_image_ensemble(
                     deplot_client,
                     cached_client,
                     tables_and_charts,
+                    trace_info=trace_info,
                 )
                 page_idx += 1
 
@@ -288,7 +292,7 @@ def process_inference_results(
 
 # Handle individual table/chart extraction and model inference
 def handle_table_chart_extraction(
-    annotation_dict, original_image, page_idx, paddle_client, deplot_client, cached_client, tables_and_charts
+    annotation_dict, original_image, page_idx, paddle_client, deplot_client, cached_client, tables_and_charts, trace_info=None,
 ):
     """
     Handle the extraction of tables and charts from the inference results and run additional model inference.
@@ -345,15 +349,15 @@ def handle_table_chart_extraction(
             )
 
             base64_img = numpy_to_base64(cropped)
-            table_content = call_image_inference_model(paddle_client, "paddle", cropped)
+            table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info)
             table_data = ImageTable(table_content, base64_img, (w1, h1, w2, h2))
             tables_and_charts.append((page_idx, table_data))
         elif label == "chart":
             cropped = crop_image(original_image, (h1, w1, h2, w2))
             base64_img = numpy_to_base64(cropped)
-            deplot_result = call_image_inference_model(deplot_client, "google/deplot", cropped)
-            cached_result = call_image_inference_model(cached_client, "cached", cropped)
+            deplot_result = call_image_inference_model(deplot_client, "google/deplot", cropped, trace_info=trace_info)
+            cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info)
             chart_content = join_cached_and_deplot_output(cached_result, deplot_result)
             chart_data = ImageChart(chart_content, base64_img, (w1, h1, w2, h2))
             tables_and_charts.append((page_idx, chart_data))
@@ -361,7 +365,7 @@
 
 # Define a helper function to use pdfium to extract text from a base64
 # encoded bytestream PDF
-def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables: bool, **kwargs):
+def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables: bool, trace_info=None, **kwargs):
     """
     Helper function to use pdfium to extract text from a bytestream PDF.
@@ -385,7 +389,6 @@ def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables:
     str
         A string of extracted text.
     """
-
     logger.debug("Extracting PDF with pdfium backend.")
 
     row_data = kwargs.get("row_data")
@@ -506,10 +509,10 @@ def pdfium(pdf_stream, extract_text: bool, extract_images: bool, extract_tables:
             extracted_data.append(text_extraction)
 
     if extract_tables:
-        for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble(pages, pdfium_config):
+        for page_idx, table_and_charts in extract_tables_and_charts_using_image_ensemble(
+            pages, pdfium_config, trace_info=trace_info
+        ):
             extracted_data.append(
                 construct_table_and_chart_metadata(
-                    table_and_charts, page_idx, pdf_metadata.page_count, source_metadata, base_unified_metadata
+                    table_and_charts, page_idx, pdf_metadata.page_count, source_metadata, base_unified_metadata,
                 )
             )
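
For orientation, a sketch of what `trace_info` is expected to hold after one traced extraction pass. The timestamps here are hypothetical; the key shape follows the `traceable_func` decorators introduced later in this patch, which suffix keys with `_<n>` for deduplication.

from datetime import datetime

trace_info = {
    "trace::entry::pdf_content_extractor::pdfium_pages_to_numpy_0": datetime.now(),
    "trace::exit::pdf_content_extractor::pdfium_pages_to_numpy_0": datetime.now(),
    "trace::entry::pdf_content_extractor::yolox_0": datetime.now(),
    "trace::exit::pdf_content_extractor::yolox_0": datetime.now(),
    "trace::entry::pdf_content_extractor::paddle_0": datetime.now(),
    "trace::exit::pdf_content_extractor::paddle_0": datetime.now(),
}
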
diff --git a/src/nv_ingest/modules/telemetry/otel_tracer.py b/src/nv_ingest/modules/telemetry/otel_tracer.py
index 9be2ecb0..05be65ae 100644
--- a/src/nv_ingest/modules/telemetry/otel_tracer.py
+++ b/src/nv_ingest/modules/telemetry/otel_tracer.py
@@ -71,29 +71,7 @@ def collect_timestamps(message):
         trace_id = int(trace_id, 16)
     span_id = RandomIdGenerator().generate_span_id()
 
-    timestamps = {}
-    for key, val in message.filter_timestamp("trace::exit::").items():
-        exit_key = key
-        entry_key = exit_key.replace("trace::exit::", "trace::entry::")
-        ts_entry = message.get_timestamp(entry_key)
-        ts_exit = message.get_timestamp(exit_key)
-        job_name = key.replace("trace::exit::", "")
-
-        ts_entry_ns = int(ts_entry.timestamp() * 1e9)
-        ts_exit_ns = int(ts_exit.timestamp() * 1e9)
-
-        timestamps[job_name] = (ts_entry_ns, ts_exit_ns)
-
-    task_results = {}
-    for key in message.list_metadata():
-        if not key.startswith("annotation::"):
-            continue
-        task = message.get_metadata(key)
-        if not (("task_id" in task) and ("task_result" in task)):
-            continue
-        task_id = task["task_id"]
-        task_result = task["task_result"]
-        task_results[task_id] = task_result
+    timestamps = extract_timestamps_from_message(message)
 
     flattened = [x for t in timestamps.values() for x in t]
     start_time = min(flattened)
@@ -107,20 +85,8 @@ def collect_timestamps(message):
     )
     parent_ctx = trace.set_span_in_context(NonRecordingSpan(span_context))
     parent_span = tracer.start_span(job_id, context=parent_ctx, start_time=start_time)
-    child_ctx = trace.set_span_in_context(parent_span)
-    for job_name, (ts_entry, ts_exit) in timestamps.items():
-        span = tracer.start_span(job_name, context=child_ctx, start_time=ts_entry)
-        if job_name in task_results:
-            task_result = task_results[job_name]
-            if task_result == TaskResultStatus.SUCCESS.value:
-                span.set_status(Status(StatusCode.OK))
-            if task_result == TaskResultStatus.FAILURE.value:
-                span.set_status(Status(StatusCode.ERROR))
-        try:
-            span.add_event("entry", timestamp=ts_entry)
-            span.add_event("exit", timestamp=ts_exit)
-        finally:
-            span.end(end_time=ts_exit)
+
+    create_span_with_timestamps(tracer, parent_span, message)
 
     if message.has_metadata("cm_failed") and message.get_metadata("cm_failed"):
         parent_span.set_status(Status(StatusCode.ERROR))
@@ -157,3 +123,83 @@ def on_next(message: ControlMessage) -> ControlMessage:
 
     builder.register_module_input("input", aggregate_node)
     builder.register_module_output("output", aggregate_node)
+
+
+def extract_timestamps_from_message(message):
+    timestamps = {}
+    dedup_counter = {}
+
+    for key, val in message.filter_timestamp("trace::exit::").items():
+        exit_key = key
+        entry_key = exit_key.replace("trace::exit::", "trace::entry::")
+
+        task_name = key.replace("trace::exit::", "")
+        if task_name in dedup_counter:
+            dedup_counter[task_name] += 1
+            task_name = task_name + "_" + str(dedup_counter[task_name])
+        else:
+            dedup_counter[task_name] = 0
+
+        ts_entry = message.get_timestamp(entry_key)
+        ts_exit = message.get_timestamp(exit_key)
+        ts_entry_ns = int(ts_entry.timestamp() * 1e9)
+        ts_exit_ns = int(ts_exit.timestamp() * 1e9)
+
+        timestamps[task_name] = (ts_entry_ns, ts_exit_ns)
+
+    return timestamps
+
+
+def extract_annotated_task_results(message):
+    task_results = {}
+    for key in message.list_metadata():
+        if not key.startswith("annotation::"):
+            continue
+        task = message.get_metadata(key)
+        if not (("task_id" in task) and ("task_result" in task)):
+            continue
+        task_id = task["task_id"]
+        task_result = task["task_result"]
+        task_results[task_id] = task_result
+
+    return task_results
+
+
+def create_span_with_timestamps(tracer, parent_span, message):
+    timestamps = extract_timestamps_from_message(message)
+    task_results = extract_annotated_task_results(message)
+
+    ctx_store = {}
+    child_ctx = trace.set_span_in_context(parent_span)
+    for task_name, (ts_entry, ts_exit) in sorted(timestamps.items(), key=lambda x: x[1]):
+        main_task, *subtask = task_name.split("::", 1)
+        subtask = "::".join(subtask)
+
+        if not subtask:
+            span = tracer.start_span(main_task, context=child_ctx, start_time=ts_entry)
+        else:
+            subtask_ctx = trace.set_span_in_context(ctx_store[main_task][0])
+            span = tracer.start_span(subtask, context=subtask_ctx, start_time=ts_entry)
+
+        # Set success/failure status.
+        if task_name in task_results:
+            task_result = task_results[task_name]
+            if task_result == TaskResultStatus.SUCCESS.value:
+                span.set_status(Status(StatusCode.OK))
+            if task_result == TaskResultStatus.FAILURE.value:
+                span.set_status(Status(StatusCode.ERROR))
+
+        # Add timestamps.
+        span.add_event("entry", timestamp=ts_entry)
+        span.add_event("exit", timestamp=ts_exit)
+
+        # Cache span and exit time.
+        # Spans are used for looking up the main task's span when creating a subtask's span.
+        # Exit timestamps are used for closing each span at the very end.
+        ctx_store[task_name] = (span, ts_exit)
+
+    for _, (span, ts_exit) in ctx_store.items():
+        span.end(end_time=ts_exit)
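
A standalone sketch of the dedup rule above (illustrative only, not repo code): repeated task names receive an `_<n>` suffix so no (entry, exit) pair is overwritten.

# Mirrors the dedup_counter logic in extract_timestamps_from_message.
def dedupe(names):
    counter, out = {}, []
    for name in names:
        if name in counter:
            counter[name] += 1
            out.append(f"{name}_{counter[name]}")
        else:
            counter[name] = 0
            out.append(name)
    return out

# A task traced twice keeps two distinct keys:
assert dedupe(["pdf_extractor", "pdf_extractor", "embed"]) == ["pdf_extractor", "pdf_extractor_1", "embed"]
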
diff --git a/src/nv_ingest/stages/multiprocessing_stage.py b/src/nv_ingest/stages/multiprocessing_stage.py
index 17cbfd1e..b7adc317 100644
--- a/src/nv_ingest/stages/multiprocessing_stage.py
+++ b/src/nv_ingest/stages/multiprocessing_stage.py
@@ -44,7 +44,7 @@ def trace_message(ctrl_msg, task_desc):
     """
     ts_fetched = datetime.now()
     do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
-            ctrl_msg.get_metadata("config::add_trace_tagging") is True
+        ctrl_msg.get_metadata("config::add_trace_tagging") is True
     )
 
     if do_trace_tagging:
@@ -149,14 +149,14 @@ class MultiProcessingBaseStage(SinglePortStage):
     """
 
     def __init__(
-            self,
-            c: Config,
-            task: str,
-            task_desc: str,
-            pe_count: int,
-            process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
-            document_type: str = None,
-            filter_properties: dict = None,
+        self,
+        c: Config,
+        task: str,
+        task_desc: str,
+        pe_count: int,
+        process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
+        document_type: str = None,
+        filter_properties: dict = None,
     ):
         super().__init__(c)
         self._document_type = document_type
@@ -199,11 +199,11 @@ def supports_cpp_node(self) -> bool:
 
     @staticmethod
    def work_package_input_handler(
-            work_package_input_queue: mp.Queue,
-            work_package_response_queue: mp.Queue,
-            cancellation_token: mp.Value,
-            process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
-            process_pool: ProcessWorkerPoolSingleton,
+        work_package_input_queue: mp.Queue,
+        work_package_response_queue: mp.Queue,
+        cancellation_token: mp.Value,
+        process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
+        process_pool: ProcessWorkerPoolSingleton,
     ):
         """
         Processes work packages received from the recv_queue, applies the process_fn to each package,
@@ -243,9 +243,13 @@ def work_package_input_handler(
                     future = process_pool.submit_task(process_fn, (df, task_props))
 
                     # This can return/raise an exception
-                    result = future.result()
+                    result, *extra_results = future.result()
 
                     work_package["payload"] = result
+                    if extra_results:
+                        for extra_result in extra_results:
+                            if isinstance(extra_result, dict) and ("trace_info" in extra_result):
+                                work_package["trace_info"] = extra_result["trace_info"]
                     work_package_response_queue.put({"type": "on_next", "value": work_package})
 
                 except Exception as e:
@@ -275,13 +279,13 @@ def work_package_input_handler(
 
     @staticmethod
     def work_package_response_handler(
-            mp_context,
-            max_queue_size,
-            work_package_input_queue: mp.Queue,
-            sub: mrc.Subscriber,
-            cancellation_token: mp.Value,
-            process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
-            process_pool: ProcessWorkerPoolSingleton,
+        mp_context,
+        max_queue_size,
+        work_package_input_queue: mp.Queue,
+        sub: mrc.Subscriber,
+        cancellation_token: mp.Value,
+        process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
+        process_pool: ProcessWorkerPoolSingleton,
     ):
         """
         Manages child threads and collects results, forwarding them to the subscriber.
@@ -478,11 +482,21 @@ def reconstruct_fn(work_package):
         def cm_func(ctrl_msg: ControlMessage, work_package: dict):
             # This is the first location where we have access to both the control message and the work package,
             # if we had any errors in the processing, raise them here.
-            if (work_package.get("error", False)):
+            if work_package.get("error", False):
                 raise RuntimeError(work_package["error_message"])
 
             gdf = cudf.from_pandas(work_package["payload"])
             ctrl_msg.payload(MessageMeta(df=gdf))
+
+            do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
+                ctrl_msg.get_metadata("config::add_trace_tagging") is True
+            )
+            if do_trace_tagging:
+                trace_info = work_package.get("trace_info")
+                if trace_info:
+                    for key, ts in trace_info.items():
+                        ctrl_msg.set_timestamp(key, ts)
+
             return ctrl_msg
 
         return cm_func(ctrl_msg, work_package)
@@ -523,7 +537,7 @@ def merge_fn(ctrl_msg: ControlMessage):
             The control message with updated tracing metadata.
         """
         do_trace_tagging = (ctrl_msg.has_metadata("config::add_trace_tagging") is True) and (
-            ctrl_msg.get_metadata("config::add_trace_tagging") is True
+            ctrl_msg.get_metadata("config::add_trace_tagging") is True
         )
 
         if do_trace_tagging:
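
A minimal sketch of the work-package convention this handler now assumes (names are illustrative): `process_fn` returns the payload DataFrame first, and any trailing dict carrying `trace_info` is copied onto the work package.

import pandas as pd

# Hypothetical stand-in for a stage's process_fn under the new convention.
def process_fn(df, task_props, trace_info=None):
    return df, {"trace_info": trace_info or {}}

result, *extra_results = process_fn(pd.DataFrame({"a": [1]}), {}, trace_info={})
work_package = {"payload": result}
for extra_result in extra_results:
    if isinstance(extra_result, dict) and ("trace_info" in extra_result):
        work_package["trace_info"] = extra_result["trace_info"]
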
diff --git a/src/nv_ingest/stages/pdf_extractor_stage.py b/src/nv_ingest/stages/pdf_extractor_stage.py
index e5c07fc5..072fce49 100644
--- a/src/nv_ingest/stages/pdf_extractor_stage.py
+++ b/src/nv_ingest/stages/pdf_extractor_stage.py
@@ -9,6 +9,8 @@
 import logging
 from typing import Any
 from typing import Dict
+from typing import List
+from typing import Optional
 
 import pandas as pd
 from morpheus.config import Config
@@ -20,10 +22,11 @@
 
 
 def decode_and_extract(
-        base64_row: Dict[str, Any],
-        task_props: Dict[str, Any],
-        validated_config: Any,
-        default: str = "pdfium"
+    base64_row: Dict[str, Any],
+    task_props: Dict[str, Any],
+    validated_config: Any,
+    default: str = "pdfium",
+    trace_info: Optional[List] = None,
 ) -> Any:
     """
     Decodes base64 content from a row and extracts data from it using the specified extraction method.
@@ -81,6 +84,8 @@ def decode_and_extract(
 
         if validated_config.pdfium_config is not None:
             extract_params["pdfium_config"] = validated_config.pdfium_config
+        if trace_info is not None:
+            extract_params["trace_info"] = trace_info
 
         if not hasattr(pdf, extract_method):
             extract_method = default
@@ -101,7 +106,7 @@ def decode_and_extract(
 #         exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id)
 
 
-def process_pdf_bytes(df, task_props, validated_config):
+def process_pdf_bytes(df, task_props, validated_config, trace_info=None):
     """
     Processes a cuDF DataFrame containing PDF files in base64 encoding.
     Each PDF's content is replaced with its extracted text.
@@ -113,11 +118,14 @@ def process_pdf_bytes(df, task_props, validated_config):
 
     Returns:
-    - A pandas DataFrame with the PDF content replaced by the extracted text.
+    - A tuple of (pandas DataFrame with the PDF content replaced by the extracted text,
+      dict carrying the accumulated `trace_info`).
     """
+    if trace_info is None:
+        trace_info = {}
 
     try:
         # Apply the helper function to each row in the 'content' column
-        _decode_and_extract = functools.partial(decode_and_extract, task_props=task_props,
-                                                validated_config=validated_config)
+        _decode_and_extract = functools.partial(
+            decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info
+        )
         logger.debug(f"processing ({task_props.get('method', None)})")
         sr_extraction = df.apply(_decode_and_extract, axis=1)
         sr_extraction = sr_extraction.explode().dropna()
@@ -127,7 +135,7 @@ def process_pdf_bytes(df, task_props, validated_config):
         else:
             extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []})
 
-        return extracted_df
+        return extracted_df, {"trace_info": trace_info}
 
     except Exception as e:
         err_msg = f"Unhandled exception in process_pdf_bytes: {e}"
@@ -137,11 +145,11 @@ def process_pdf_bytes(df, task_props, validated_config):
 
 
 def generate_pdf_extractor_stage(
-        c: Config,
-        extractor_config: Dict[str, Any],
-        task: str = "extract",
-        task_desc: str = "pdf_content_extractor",
-        pe_count: int = 24,
+    c: Config,
+    extractor_config: Dict[str, Any],
+    task: str = "extract",
+    task_desc: str = "pdf_content_extractor",
+    pe_count: int = 24,
 ):
     """
     Helper function to generate a multiprocessing stage to perform pdf content extraction.
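
The partial above binds a single `trace_info` dict that every row's extraction shares, so per-row decorator calls accumulate into one mapping. A small self-contained illustration (hypothetical helper, not repo code):

import functools

# Stand-in for decode_and_extract: each row records into the shared dict.
def decode_row(row, task_props, trace_info=None):
    trace_info[f"trace::entry::row_{row}"] = "..."
    return row

trace_info = {}
_decode = functools.partial(decode_row, task_props={}, trace_info=trace_info)
for row in ("a", "b"):
    _decode(row)
assert len(trace_info) == 2  # both rows accumulated into the one shared dict
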
""" + if trace_info is None: + trace_info = {} try: # Apply the helper function to each row in the 'content' column - _decode_and_extract = functools.partial(decode_and_extract, task_props=task_props, - validated_config=validated_config) + _decode_and_extract = functools.partial( + decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info + ) logger.debug(f"processing ({task_props.get('method', None)})") sr_extraction = df.apply(_decode_and_extract, axis=1) sr_extraction = sr_extraction.explode().dropna() @@ -127,7 +135,7 @@ def process_pdf_bytes(df, task_props, validated_config): else: extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []}) - return extracted_df + return extracted_df, {"trace_info": trace_info} except Exception as e: err_msg = f"Unhandled exception in process_pdf_bytes: {e}" @@ -137,11 +145,11 @@ def process_pdf_bytes(df, task_props, validated_config): def generate_pdf_extractor_stage( - c: Config, - extractor_config: Dict[str, Any], - task: str = "extract", - task_desc: str = "pdf_content_extractor", - pe_count: int = 24, + c: Config, + extractor_config: Dict[str, Any], + task: str = "extract", + task_desc: str = "pdf_content_extractor", + pe_count: int = 24, ): """ Helper function to generate a multiprocessing stage to perform pdf content extraction. diff --git a/src/nv_ingest/util/multi_processing/mp_pool_singleton.py b/src/nv_ingest/util/multi_processing/mp_pool_singleton.py index 6e3da80f..3aae3ff2 100644 --- a/src/nv_ingest/util/multi_processing/mp_pool_singleton.py +++ b/src/nv_ingest/util/multi_processing/mp_pool_singleton.py @@ -207,9 +207,10 @@ def _worker(task_queue: mp.Queue, manager: mp.Manager) -> None: logger.debug(f"Worker process {os.getpid()} received stop signal.") break - process_fn, args, future = task + future, process_fn, args = task + args, *kwargs = args try: - result = process_fn(*args[0]) + result = process_fn(*args, **{k: v for kwarg in kwargs for k, v in kwarg.items()}) future.set_result(result) except Exception as e: logger.error(f"Future result failure - {e}\n") @@ -232,7 +233,7 @@ def submit_task(self, process_fn: Callable, *args: Any) -> SimpleFuture: A future object representing the result of the task. """ future = SimpleFuture(self._manager) - self._task_queue.put((process_fn, args, future)) + self._task_queue.put((future, process_fn, args)) return future def close(self) -> None: diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 7f964a67..c05aa0f1 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -11,6 +11,7 @@ import tritonclient.grpc as grpcclient from nv_ingest.util.image_processing.transforms import numpy_to_base64 +from nv_ingest.util.tracing.tagging import traceable_func logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str return {"endpoint_url": endpoints[1], "headers": headers} +@traceable_func(trace_name="pdf_content_extractor::{model_name}") def call_image_inference_model(client, model_name: str, image_data): """ Calls an image inference model using the provided client. 
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index 7f964a67..c05aa0f1 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -11,6 +11,7 @@
 import tritonclient.grpc as grpcclient
 
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.tracing.tagging import traceable_func
 
 logger = logging.getLogger(__name__)
 
@@ -47,6 +48,7 @@ def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str
         return {"endpoint_url": endpoints[1], "headers": headers}
 
 
+@traceable_func(trace_name="pdf_content_extractor::{model_name}")
 def call_image_inference_model(client, model_name: str, image_data):
     """
     Calls an image inference model using the provided client.
@@ -87,6 +89,7 @@ def call_image_inference_model(client, model_name: str, image_data):
             err_msg = f"Inference failed for model {model_name}: {str(e)}"
             logger.error(err_msg)
             raise RuntimeError(err_msg)
+
     else:
         base64_img = numpy_to_base64(image_data)
 
@@ -131,6 +134,7 @@ def call_image_inference_model(client, model_name: str, image_data):
 
 
 # Perform inference and return predictions
+@traceable_func(trace_name="pdf_content_extractor::{model_name}")
 def perform_model_inference(client, model_name: str, input_array: np.ndarray):
     """
     Perform inference using the provided model and input data.
diff --git a/src/nv_ingest/util/pdf/pdfium.py b/src/nv_ingest/util/pdf/pdfium.py
index f8060b99..72843905 100644
--- a/src/nv_ingest/util/pdf/pdfium.py
+++ b/src/nv_ingest/util/pdf/pdfium.py
@@ -15,6 +15,7 @@
 from PIL import Image
 
 from nv_ingest.util.image_processing.transforms import pad_image
+from nv_ingest.util.tracing.tagging import traceable_func
 
 logger = logging.getLogger(__name__)
 
@@ -115,6 +116,7 @@ def pdfium_try_get_bitmap_as_numpy(image_obj) -> np.ndarray:
     return img_array
 
 
+@traceable_func(trace_name="pdf_content_extractor::pdfium_pages_to_numpy")
 def pdfium_pages_to_numpy(
     pages: List[pdfium.PdfPage],
     render_dpi=300,
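
To make the decorator concrete, a hedged sketch (`fake_inference` is hypothetical; the key shape matches the tests added below): the `{model_name}` placeholder is filled from the call's arguments, and dedupe appends an index to the keys.

from nv_ingest.util.tracing.tagging import traceable_func

@traceable_func(trace_name="pdf_content_extractor::{model_name}")
def fake_inference(client, model_name, image_data):
    return "ok"

trace_info = {}
fake_inference(None, "paddle", None, trace_info=trace_info)
assert "trace::entry::pdf_content_extractor::paddle_0" in trace_info
assert "trace::exit::pdf_content_extractor::paddle_0" in trace_info
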
diff --git a/src/nv_ingest/util/tracing/tagging.py b/src/nv_ingest/util/tracing/tagging.py
index 519fcd86..e151ab8c 100644
--- a/src/nv_ingest/util/tracing/tagging.py
+++ b/src/nv_ingest/util/tracing/tagging.py
@@ -4,6 +4,8 @@
 
 
 import functools
+import inspect
+import string
 from datetime import datetime
 
 
@@ -90,3 +92,104 @@ def wrapper_trace_tagging(*args, **kwargs):
         return wrapper_trace_tagging
 
     return decorator_trace_tagging
+
+
+def traceable_func(trace_name=None, dedupe=True):
+    """
+    A decorator that injects trace information for tracking the execution of a function.
+    It logs the entry and exit timestamps of the function in a `trace_info` dictionary,
+    which can be used for performance monitoring or debugging purposes.
+
+    Parameters
+    ----------
+    trace_name : str, optional
+        An optional string used as the prefix for the trace log entries. If not provided,
+        the decorated function's name is used. The string can include placeholders (e.g.,
+        "pdf_extractor::{model_name}") that will be dynamically replaced with matching
+        function argument values.
+    dedupe : bool, optional
+        If True, ensures that the trace entry and exit keys are unique by appending an index
+        (e.g., `_0`, `_1`) to the keys if duplicate entries are detected. Default is True.
+
+    Returns
+    -------
+    function
+        A wrapped function that injects trace information before and after the function's
+        execution.
+
+    Notes
+    -----
+    - If `trace_info` is not provided in the keyword arguments, a new dictionary is created
+      and used for storing trace entries.
+    - If `trace_name` contains format placeholders, the decorator attempts to populate them
+      with matching argument values from the decorated function.
+    - The trace information is logged in the format:
+        - `trace::entry::{trace_name}` for the entry timestamp.
+        - `trace::exit::{trace_name}` for the exit timestamp.
+    - If `dedupe` is True, the trace keys will be appended with an index to avoid
+      overwriting existing entries.
+
+    Example
+    -------
+    >>> @traceable_func(trace_name="pdf_extractor::{model_name}")
+    ... def extract_pdf(model_name):
+    ...     pass
+    >>> trace_info = {}
+    >>> extract_pdf("my_model", trace_info=trace_info)
+
+    In this example, `model_name` is dynamically replaced in the trace_name, and the
+    trace information is logged with unique keys if deduplication is enabled.
+    """
+
+    def decorator_inject_trace_info(func):
+        @functools.wraps(func)
+        def wrapper_inject_trace_info(*args, **kwargs):
+            trace_info = kwargs.pop("trace_info", None)
+            if trace_info is None:
+                trace_info = {}
+            trace_prefix = trace_name if trace_name else func.__name__
+
+            # If `trace_prefix` is a formattable string, e.g., "pdf_extractor::{model_name}",
+            # search `args` and `kwargs` to replace the placeholder.
+            placeholders = [x[1] for x in string.Formatter().parse(trace_prefix) if x[1] is not None]
+            if placeholders:
+                arg_names = inspect.signature(func).parameters
+                positional_args = dict(zip(arg_names, args))
+                format_kwargs = {}
+                for name in placeholders:
+                    if name in positional_args:
+                        arg_val = positional_args[name]
+                    elif name in kwargs:
+                        arg_val = kwargs[name]
+                    else:
+                        continue
+                    format_kwargs[name] = arg_val
+                trace_prefix = trace_prefix.format(**format_kwargs)
+
+            trace_entry_key = f"trace::entry::{trace_prefix}"
+            trace_exit_key = f"trace::exit::{trace_prefix}"
+
+            ts_entry = datetime.now()
+
+            if dedupe:
+                trace_entry_key += "_{}"
+                trace_exit_key += "_{}"
+                i = 0
+                while (trace_entry_key.format(i) in trace_info) or (trace_exit_key.format(i) in trace_info):
+                    i += 1
+                trace_entry_key = trace_entry_key.format(i)
+                trace_exit_key = trace_exit_key.format(i)
+
+            trace_info[trace_entry_key] = ts_entry
+
+            # Call the decorated function
+            result = func(*args, **kwargs)
+
+            ts_exit = datetime.now()
+
+            trace_info[trace_exit_key] = ts_exit
+
+            return result
+
+        return wrapper_inject_trace_info
+
+    return decorator_inject_trace_info
diff --git a/tests/nv_ingest/modules/telemetry/test_otel_tracer.py b/tests/nv_ingest/modules/telemetry/test_otel_tracer.py
new file mode 100644
index 00000000..d00ae7df
--- /dev/null
+++ b/tests/nv_ingest/modules/telemetry/test_otel_tracer.py
@@ -0,0 +1,69 @@
+from datetime import datetime
+
+from morpheus.messages import ControlMessage
+
+from nv_ingest.modules.telemetry.otel_tracer import extract_annotated_task_results
+from nv_ingest.modules.telemetry.otel_tracer import extract_timestamps_from_message
+
+
+def test_extract_timestamps_single_task():
+    msg = ControlMessage()
+    msg.set_timestamp("trace::entry::foo", datetime.fromtimestamp(1))
+    msg.set_timestamp("trace::exit::foo", datetime.fromtimestamp(2))
+
+    expected_output = {"foo": (int(1e9), int(2e9))}  # Convert seconds to nanoseconds
+
+    result = extract_timestamps_from_message(msg)
+
+    assert result == expected_output
+
+
+def test_extract_timestamps_no_tasks():
+    msg = ControlMessage()
+
+    expected_output = {}
+
+    result = extract_timestamps_from_message(msg)
+
+    assert result == expected_output
+
+
+def test_extract_annotated_task_results_invalid_metadata():
+    msg = ControlMessage()
+
+    # Simulate setting non-annotation metadata and valid annotation metadata
+    msg.set_metadata("random::metadata", {"random_key": "value"})  # Should be ignored
+    msg.set_metadata("annotation::task1", {"task_id": "task1", "task_result": "success"})
+
+    expected_output = {"task1": "success"}
+
+    result = extract_annotated_task_results(msg)
+
+    assert result == expected_output
+
+
+def test_extract_annotated_task_results_missing_fields():
+    msg = ControlMessage()
+
+    # Simulate setting metadata with missing task_id and task_result
+    msg.set_metadata("annotation::task1", {"task_result": "success"})  # Missing task_id (should be skipped)
+    msg.set_metadata("annotation::task2", {"task_id": "task2"})  # Missing task_result (should be skipped)
+
+    expected_output = {}
+
+    result = extract_annotated_task_results(msg)
+
+    assert result == expected_output
+
+
+def test_extract_annotated_task_results_no_annotation_keys():
+    msg = ControlMessage()
+
+    # Simulate setting metadata with no annotation keys
+    msg.set_metadata("random::metadata", {"random_key": "value"})
+
+    expected_output = {}
+
+    result = extract_annotated_task_results(msg)
+
+    assert result == expected_output
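
For reference, the annotation shape these helpers consume, mirroring the tests above (assumes morpheus and nv_ingest are importable):

from morpheus.messages import ControlMessage
from nv_ingest.modules.telemetry.otel_tracer import extract_annotated_task_results

msg = ControlMessage()
msg.set_metadata("annotation::pdf_extractor", {"task_id": "pdf_extractor", "task_result": "success"})
assert extract_annotated_task_results(msg) == {"pdf_extractor": "success"}
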
+ """ + trace_info = {} + result1 = simple_func("dedupe_test", trace_info=trace_info) + result2 = simple_func("dedupe_test", trace_info=trace_info) + + assert result1 == "Processed dedupe_test" + assert result2 == "Processed dedupe_test" + + assert "trace::entry::simple_func::dedupe_test_0" in trace_info + assert "trace::exit::simple_func::dedupe_test_0" in trace_info + assert "trace::entry::simple_func::dedupe_test_1" in trace_info + assert "trace::exit::simple_func::dedupe_test_1" in trace_info + + assert isinstance(trace_info["trace::entry::simple_func::dedupe_test_0"], datetime) + assert isinstance(trace_info["trace::exit::simple_func::dedupe_test_0"], datetime) + assert isinstance(trace_info["trace::entry::simple_func::dedupe_test_1"], datetime) + assert isinstance(trace_info["trace::exit::simple_func::dedupe_test_1"], datetime) + + +def test_traceable_func_without_trace_info(): + """ + Test that traceable_func creates a new trace_info dictionary if one is not passed. + """ + result = simple_func("no_trace_info") + + assert result == "Processed no_trace_info" + + +def test_traceable_func_with_multiple_args(): + """ + Test that traceable_func handles functions with multiple arguments and formats trace_name accordingly. + """ + + @traceable_func(trace_name="multi_args_func::{arg1}::{arg2}") + def multi_args_func(arg1, arg2, **kwargs): + return f"Processed {arg1} and {arg2}" + + trace_info = {} + result = multi_args_func("first_value", "second_value", trace_info=trace_info) + + assert result == "Processed first_value and second_value" + assert "trace::entry::multi_args_func::first_value::second_value_0" in trace_info + assert "trace::exit::multi_args_func::first_value::second_value_0" in trace_info + + +def test_traceable_func_dedupe_disabled(): + """ + Test that the traceable_func does not deduplicate trace keys when dedupe=False. + """ + + @traceable_func(trace_name="no_dedupe_test", dedupe=False) + def no_dedupe_test(param, **kwargs): + return f"Processed {param}" + + trace_info = {} + result1 = no_dedupe_test("no_dedupe", trace_info=trace_info) + result2 = no_dedupe_test("no_dedupe", trace_info=trace_info) + + assert result1 == "Processed no_dedupe" + assert result2 == "Processed no_dedupe" + + assert "trace::entry::no_dedupe_test" in trace_info + assert "trace::exit::no_dedupe_test" in trace_info + + assert "trace::entry::no_dedupe_test_1" not in trace_info + assert "trace::exit::no_dedupe_test_1" not in trace_info