diff --git a/docker-compose.yaml b/docker-compose.yaml
index 629170d6..fede5634 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -14,8 +14,6 @@ services:
- "8000:8000"
- "8001:8001"
- "8002:8002"
- volumes:
- - ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
@@ -37,8 +35,6 @@ services:
- "8003:8000"
- "8004:8001"
- "8005:8002"
- volumes:
- - ${HOME}/.cache:/opt/nim/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
@@ -61,8 +57,6 @@ services:
- "8006:8000"
- "8007:8001"
- "8008:8002"
- volumes:
- - ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
@@ -85,8 +79,6 @@ services:
- "8009:8000"
- "8010:8001"
- "8011:8002"
- volumes:
- - ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
@@ -138,8 +130,8 @@ services:
- sys_nice
environment:
- CACHED_GRPC_ENDPOINT=cached:8001
- - CACHED_HEALTH_ENDPOINT=cached:8000
- - CACHED_HTTP_ENDPOINT=""
+ - CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer
+ - CACHED_INFER_PROTOCOL=grpc
- CUDA_VISIBLE_DEVICES=1
- DEPLOT_GRPC_ENDPOINT=""
# self hosted deplot
@@ -147,6 +139,7 @@ services:
- DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions
# build.nvidia.com hosted deplot
#- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot
+ - DEPLOT_INFER_PROTOCOL=http
- DOUGHNUT_GRPC_TRITON=triton-doughnut:8001
- INGEST_LOG_LEVEL=DEFAULT
- MESSAGE_CLIENT_HOST=redis
@@ -157,15 +150,13 @@ services:
- NVIDIA_BUILD_API_KEY=${NVIDIA_BUILD_API_KEY:-${NGC_API_KEY:-ngcapikey}}
- OTEL_EXPORTER_OTLP_ENDPOINT=otel-collector:4317
- PADDLE_GRPC_ENDPOINT=paddle:8001
- - PADDLE_HEALTH_ENDPOINT=paddle:8000
- - PADDLE_HTTP_ENDPOINT=""
+ - PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer
+ - PADDLE_INFER_PROTOCOL=grpc
- READY_CHECK_ALL_COMPONENTS=True
- REDIS_MORPHEUS_TASK_QUEUE=morpheus_task_queue
- - TABLE_DETECTION_GRPC_TRITON=yolox:8001
- - TABLE_DETECTION_HTTP_TRITON=""
- YOLOX_GRPC_ENDPOINT=yolox:8001
- - YOLOX_HEALTH_ENDPOINT=yolox:8000
- - YOLOX_HTTP_ENDPOINT=""
+ - YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer
+ - YOLOX_INFER_PROTOCOL=grpc
healthcheck:
test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1
interval: 10s
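
Reviewer note: the compose changes above replace the per-service `*_HEALTH_ENDPOINT` variables and empty `*_HTTP_ENDPOINT` placeholders with fully populated HTTP inference URLs plus an explicit `*_INFER_PROTOCOL` selector. A minimal sketch of the selection rule these variables feed, mirroring `get_table_detection_service` in src/pipeline.py further down; the values are the compose defaults:

import os

os.environ.setdefault("PADDLE_GRPC_ENDPOINT", "paddle:8001")
os.environ.setdefault("PADDLE_HTTP_ENDPOINT", "http://paddle:8000/v1/infer")
os.environ.setdefault("PADDLE_INFER_PROTOCOL", "grpc")

# An explicit protocol wins; otherwise HTTP is preferred when both endpoints are set.
protocol = os.environ.get("PADDLE_INFER_PROTOCOL") or (
    "http" if os.environ.get("PADDLE_HTTP_ENDPOINT") else "grpc"
)
assert protocol == "grpc"
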
diff --git a/helm/values.yaml b/helm/values.yaml
index 3dd65739..33d2ad57 100644
--- a/helm/values.yaml
+++ b/helm/values.yaml
@@ -251,8 +251,6 @@ redis:
## @param envVars.MINIO_PUBLIC_ADDRESS [default: "http://localhost:9000"] Override this to publicly routable minio address, default assumes port-forwarding
## @param envVars.MINIO_BUCKET [default: "nv-ingest"] Override this for specific minio bucket to upload extracted images to
## @skip envVars.REDIS_MORPHEUS_TASK_QUEUE
-## @skip envVars.TABLE_DETECTION_GRPC_TRITON
-## @skip envVars.TABLE_DETECTION_HTTP_TRITON
## @skip envVars.CACHED_GRPC_ENDPOINT
## @skip envVars.CACHED_HTTP_ENDPOINT
## @skip envVars.PADDLE_GRPC_ENDPOINT
@@ -271,9 +269,6 @@ envVars:
MINIO_PUBLIC_ADDRESS: http://localhost:9000
MINIO_BUCKET: nv-ingest
- TABLE_DETECTION_GRPC_TRITON: nv-ingest-yolox:8001
- TABLE_DETECTION_HTTP_TRITON: ""
-
CACHED_GRPC_ENDPOINT: nv-ingest-cached:8001
CACHED_HTTP_ENDPOINT: ""
PADDLE_GRPC_ENDPOINT: nv-ingest-paddle:8001
diff --git a/src/nv_ingest/api/v1/health.py b/src/nv_ingest/api/v1/health.py
index 6fe4a9cd..cdc76906 100644
--- a/src/nv_ingest/api/v1/health.py
+++ b/src/nv_ingest/api/v1/health.py
@@ -66,14 +66,14 @@ async def get_ready_state() -> dict:
# for now to assume that if nv-ingest is running so is
# the pipeline.
morpheus_pipeline_ready = True
-
+
# We give the users an option to disable checking all distributed services for "readiness"
check_all_components = os.getenv("READY_CHECK_ALL_COMPONENTS", "True").lower()
if check_all_components in ['1', 'true', 'yes']:
- yolox_ready = is_ready(os.getenv("YOLOX_HEALTH_ENDPOINT", None), "/v1/health/ready")
- deplot_ready = is_ready(os.getenv("DEPLOT_HEALTH_ENDPOINT", None), "/v1/health/ready")
- cached_ready = is_ready(os.getenv("CACHED_HEALTH_ENDPOINT", None), "/v1/health/ready")
- paddle_ready = is_ready(os.getenv("PADDLE_HEALTH_ENDPOINT", None), "/v1/health/ready")
+ yolox_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready")
+ deplot_ready = is_ready(os.getenv("DEPLOT_HTTP_ENDPOINT", None), "/v1/health/ready")
+ cached_ready = is_ready(os.getenv("CACHED_HTTP_ENDPOINT", None), "/v1/health/ready")
+ paddle_ready = is_ready(os.getenv("PADDLE_HTTP_ENDPOINT", None), "/v1/health/ready")
if (ingest_ready
and morpheus_pipeline_ready
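
Reviewer note: with the `*_HEALTH_ENDPOINT` variables gone, the readiness probe is now derived from the inference URL itself. A sketch of the URL rewriting that `is_ready` performs, using `generate_url` and `remove_url_endpoints` from nv_ingest/util/nim/helpers.py, both updated later in this diff:

from nv_ingest.util.nim.helpers import generate_url, remove_url_endpoints

# e.g. PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer
base = remove_url_endpoints(generate_url("paddle:8000/v1/infer"))  # "http://paddle:8000"
ready_url = base + "/v1/health/ready"  # the URL is_ready() ultimately polls
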
diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index ae3abf15..ee66fcca 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -35,7 +35,9 @@
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import call_image_inference_model
from nv_ingest.util.nim.helpers import create_inference_client
+from nv_ingest.util.nim.helpers import get_version
from nv_ingest.util.nim.helpers import perform_model_inference
+from nv_ingest.util.nim.helpers import preprocess_image_for_paddle
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import ImageChart
from nv_ingest.util.pdf.metadata_aggregators import ImageTable
@@ -47,7 +49,6 @@
from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
from nv_ingest.util.pdf.pdfium import pdfium_try_get_bitmap_as_numpy
-
PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32
YOLOX_MAX_BATCH_SIZE = 8
@@ -140,12 +141,20 @@ def extract_tables_and_charts_using_image_ensemble(
yolox_client = paddle_client = deplot_client = cached_client = None
try:
- yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token)
+        yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token, config.yolox_infer_protocol)
+        paddle_version = None
if extract_tables:
- paddle_client = create_inference_client(config.paddle_endpoints, config.auth_token)
+ paddle_client = create_inference_client(
+ config.paddle_endpoints, config.auth_token, config.paddle_infer_protocol
+ )
+ paddle_version = get_version(config.paddle_endpoints[1])
if extract_charts:
- cached_client = create_inference_client(config.cached_endpoints, config.auth_token)
- deplot_client = create_inference_client(config.deplot_endpoints, config.auth_token)
+ cached_client = create_inference_client(
+ config.cached_endpoints, config.auth_token, config.cached_infer_protocol
+ )
+ deplot_client = create_inference_client(
+ config.deplot_endpoints, config.auth_token, config.deplot_infer_protocol
+ )
batches = []
i = 0
@@ -164,7 +172,6 @@ def extract_tables_and_charts_using_image_ensemble(
input_array = prepare_images_for_inference(original_images)
output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info)
-
results = process_inference_results(
output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh
)
@@ -180,6 +187,7 @@ def extract_tables_and_charts_using_image_ensemble(
tables_and_charts,
extract_tables=extract_tables,
extract_charts=extract_charts,
+ paddle_version=paddle_version,
trace_info=trace_info,
)
@@ -313,6 +321,7 @@ def handle_table_chart_extraction(
tables_and_charts,
extract_tables=True,
extract_charts=True,
+ paddle_version=None,
trace_info=None,
):
"""
@@ -368,20 +377,28 @@ def handle_table_chart_extraction(
min_width=PADDLE_MIN_WIDTH,
min_height=PADDLE_MIN_HEIGHT,
)
+ if cropped is None:
+ continue
+
base64_img = numpy_to_base64(cropped)
+ if isinstance(paddle_client, grpcclient.InferenceServerClient):
+ cropped = preprocess_image_for_paddle(cropped, paddle_version=paddle_version)
+
table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info)
table_data = ImageTable(
content=table_content, image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, max_height=height
)
tables_and_charts.append((page_idx, table_data))
+
elif extract_charts and label == "chart":
cropped = crop_image(original_image, (h1, w1, h2, w2))
+ if cropped is None:
+ continue
+
base64_img = numpy_to_base64(cropped)
- deplot_result = call_image_inference_model(
- deplot_client, "google/deplot", cropped, trace_info=trace_info
- )
+ deplot_result = call_image_inference_model(deplot_client, "deplot", cropped, trace_info=trace_info)
cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info)
chart_content = join_cached_and_deplot_output(cached_result, deplot_result)
chart_data = ImageChart(
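
Reviewer note: only the gRPC paddle path needs the new preprocessing; the HTTP path posts the raw crop as base64. A quick shape sanity check for `preprocess_image_for_paddle` (added to nv_ingest/util/nim/helpers.py below); the crop size is illustrative:

import numpy as np
from nv_ingest.util.nim.helpers import preprocess_image_for_paddle

crop = np.zeros((500, 640, 3), dtype=np.uint8)
out = preprocess_image_for_paddle(crop, paddle_version="0.2.0")
# Longest side scaled to 960 -> (750, 960, 3), padded up to multiples of 32
# -> (768, 960, 3), then transposed to channel-first:
assert out.shape == (3, 768, 960)
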
diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py
index 2826bee4..85a1716b 100644
--- a/src/nv_ingest/schemas/pdf_extractor_schema.py
+++ b/src/nv_ingest/schemas/pdf_extractor_schema.py
@@ -64,6 +64,11 @@ class PDFiumConfigSchema(BaseModel):
paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+ cached_infer_protocol: str = ""
+ deplot_infer_protocol: str = ""
+ paddle_infer_protocol: str = ""
+ yolox_infer_protocol: str = ""
+
identify_nearby_objects: bool = False
@root_validator(pre=True)
@@ -93,7 +98,8 @@ def clean_service(service):
return None
return service
- for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints", "yolox_endpoints"]:
+ for model_name in ["cached", "deplot", "paddle", "yolox"]:
+ endpoint_name = f"{model_name}_endpoints"
grpc_service, http_service = values.get(endpoint_name)
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)
@@ -103,6 +109,13 @@ def clean_service(service):
values[endpoint_name] = (grpc_service, http_service)
+ protocol_name = f"{model_name}_infer_protocol"
+ protocol_value = values.get(protocol_name)
+ if not protocol_value:
+ protocol_value = "http" if http_service else "grpc" if grpc_service else ""
+ protocol_value = protocol_value.lower()
+ values[protocol_name] = protocol_value
+
return values
class Config:
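
Reviewer note: the validator now derives a per-service protocol whenever one is not given explicitly. A standalone restatement of that defaulting rule (the helper name is illustrative, not part of the codebase):

def default_protocol(grpc_service, http_service, protocol_value=""):
    # Explicit value wins; otherwise prefer HTTP, then gRPC, else empty.
    if not protocol_value:
        protocol_value = "http" if http_service else "grpc" if grpc_service else ""
    return protocol_value.lower()

assert default_protocol("paddle:8001", None) == "grpc"
assert default_protocol(None, "http://yolox:8000/v1/infer") == "http"
assert default_protocol("cached:8001", "http://cached:8000/v1/infer", "GRPC") == "grpc"
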
diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py
index 9e79c9d7..72e73934 100644
--- a/src/nv_ingest/util/image_processing/table_and_chart.py
+++ b/src/nv_ingest/util/image_processing/table_and_chart.py
@@ -43,7 +43,13 @@ def join_cached_and_deplot_output(cached_text, deplot_text):
if (cached_text is not None):
try:
- cached_text_dict = json.loads(cached_text)
+ if isinstance(cached_text, str):
+ cached_text_dict = json.loads(cached_text)
+ elif isinstance(cached_text, dict):
+ cached_text_dict = cached_text
+ else:
+ cached_text_dict = {}
+
chart_content += cached_text_dict.get("chart_title", "")
if (deplot_text is not None):
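
Reviewer note: this guard exists because the cached result may now arrive either as a JSON string or as an already-parsed dict, depending on the inference path. Both forms should join identically; a small usage sketch, with illustrative chart content:

from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output

as_string = join_cached_and_deplot_output('{"chart_title": "Revenue"}', "x | y | 1 | 2")
as_dict = join_cached_and_deplot_output({"chart_title": "Revenue"}, "x | y | 1 | 2")
assert as_string == as_dict
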
diff --git a/src/nv_ingest/util/image_processing/transforms.py b/src/nv_ingest/util/image_processing/transforms.py
index 6b0acb03..f5e1dc11 100644
--- a/src/nv_ingest/util/image_processing/transforms.py
+++ b/src/nv_ingest/util/image_processing/transforms.py
@@ -18,7 +18,11 @@
def pad_image(
- array: np.ndarray, target_width: int = DEFAULT_MAX_WIDTH, target_height: int = DEFAULT_MAX_HEIGHT
+ array: np.ndarray,
+ target_width: int = DEFAULT_MAX_WIDTH,
+ target_height: int = DEFAULT_MAX_HEIGHT,
+ background_color: int = 255,
+ dtype=np.uint8,
) -> Tuple[np.ndarray, Tuple[int, int]]:
"""
Pads a NumPy array representing an image to the specified target dimensions.
@@ -68,7 +72,7 @@ def pad_image(
final_width = max(width, target_width)
# Create the canvas and place the original image on it
- canvas = 255 * np.ones((final_height, final_width, array.shape[2]), dtype=np.uint8)
+ canvas = background_color * np.ones((final_height, final_width, array.shape[2]), dtype=dtype)
canvas[pad_height : pad_height + height, pad_width : pad_width + width] = array # noqa: E203
return canvas, (pad_width, pad_height)
@@ -113,6 +117,64 @@
return cropped
+def normalize_image(
+ array: np.ndarray,
+ r_mean: float = 0.485,
+ g_mean: float = 0.456,
+ b_mean: float = 0.406,
+ r_std: float = 0.229,
+ g_std: float = 0.224,
+ b_std: float = 0.225,
+) -> np.ndarray:
+ """
+ Normalizes an RGB image by applying a mean and standard deviation to each channel.
+
+ Parameters:
+ ----------
+ array : np.ndarray
+ The input image array, which can be either grayscale or RGB. The image should have a shape of
+ (height, width, 3) for RGB images, or (height, width) or (height, width, 1) for grayscale images.
+ If a grayscale image is provided, it will be converted to RGB format by repeating the grayscale values
+ across all three channels (R, G, B).
+ r_mean : float, optional
+ The mean to be subtracted from the red channel (default is 0.485).
+ g_mean : float, optional
+ The mean to be subtracted from the green channel (default is 0.456).
+ b_mean : float, optional
+ The mean to be subtracted from the blue channel (default is 0.406).
+ r_std : float, optional
+ The standard deviation to divide the red channel by (default is 0.229).
+ g_std : float, optional
+ The standard deviation to divide the green channel by (default is 0.224).
+ b_std : float, optional
+ The standard deviation to divide the blue channel by (default is 0.225).
+
+ Returns:
+ -------
+ np.ndarray
+ A normalized image array with the same shape as the input, where the RGB channels have been normalized
+ by the given means and standard deviations.
+
+ Notes:
+ -----
+ The input pixel values should be in the range [0, 255], and the function scales these values to [0, 1]
+ before applying normalization.
+
+ If the input image is grayscale, it is converted to an RGB image by duplicating the grayscale values
+ across the three color channels.
+ """
+ # If the input is a grayscale image with shape (height, width) or (height, width, 1),
+ # convert it to RGB with shape (height, width, 3).
+ if array.ndim == 2 or array.shape[2] == 1:
+        array = np.dstack((array, array, array))
+
+ mean = np.array([r_mean, g_mean, b_mean]).reshape((1, 1, 3)).astype(np.float32)
+ std = np.array([r_std, g_std, b_std]).reshape((1, 1, 3)).astype(np.float32)
+ output_array = (array.astype("float32") / 255.0 - mean) / std
+
+ return output_array
+
+
def numpy_to_base64(array: np.ndarray) -> str:
"""
Converts a NumPy array representing an image to a base64-encoded string.
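
Reviewer note: `normalize_image` applies the familiar ImageNet-style per-channel standardization after scaling pixels to [0, 1]. A worked example, assuming a pure-white input:

import numpy as np
from nv_ingest.util.image_processing.transforms import normalize_image

img = np.full((2, 2, 3), 255, dtype=np.uint8)
out = normalize_image(img)
# Red channel: (255/255 - 0.485) / 0.229 ≈ 2.249
assert out.shape == (2, 2, 3) and out.dtype == np.float32
assert np.allclose(out[0, 0, 0], (1.0 - 0.485) / 0.229, atol=1e-4)
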
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index 80af504c..4bdbfbf4 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -3,21 +3,36 @@
# SPDX-License-Identifier: Apache-2.0
import logging
+import re
+from typing import Any
+from typing import Dict
from typing import Optional
from typing import Tuple
+import backoff
+import cv2
import numpy as np
-import re
+import packaging.version
import requests
import tritonclient.grpc as grpcclient
+from nv_ingest.util.image_processing.transforms import normalize_image
from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.image_processing.transforms import pad_image
from nv_ingest.util.tracing.tagging import traceable_func
logger = logging.getLogger(__name__)
+DEPLOT_MAX_TOKENS = 128
+DEPLOT_TEMPERATURE = 1.0
+DEPLOT_TOP_P = 1.0
+
-def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str]):
+def create_inference_client(
+ endpoints: Tuple[str, str],
+ auth_token: Optional[str] = None,
+ infer_protocol: Optional[str] = None,
+):
"""
Creates an inference client based on the provided endpoints.
@@ -36,17 +51,24 @@ def create_inference_client(endpoints: Tuple[str, str], auth_token: Optional[str
grpcclient.InferenceServerClient or dict
A gRPC client if the gRPC endpoint is provided, otherwise a dictionary containing the HTTP client details.
"""
- if endpoints[0] and endpoints[0].strip():
- logger.debug(f"Creating gRPC client with {endpoints}")
- return grpcclient.InferenceServerClient(url=endpoints[0])
- else:
- logger.debug(f"Creating HTTP client with {endpoints}")
+ grpc_endpoint, http_endpoint = endpoints
+
+ if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()):
+ infer_protocol = "grpc"
+
+ if infer_protocol == "grpc":
+ logger.debug(f"Creating gRPC client with {grpc_endpoint}")
+ return grpcclient.InferenceServerClient(url=grpc_endpoint)
+ elif infer_protocol == "http":
+ url = generate_url(http_endpoint)
+
+ logger.debug(f"Creating HTTP client with {http_endpoint}")
headers = {"accept": "application/json", "content-type": "application/json"}
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
- return {"endpoint_url": endpoints[1], "headers": headers}
+ return {"endpoint_url": url, "headers": headers}
@traceable_func(trace_name="pdf_content_extractor::{model_name}")
@@ -76,62 +98,113 @@ def call_image_inference_model(client, model_name: str, image_data):
If the HTTP request fails or if the response format is not as expected.
"""
if isinstance(client, grpcclient.InferenceServerClient):
- if image_data.ndim == 3:
- image_data = np.expand_dims(image_data, axis=0)
- inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")]
- inputs[0].set_data_from_numpy(image_data.astype(np.float32))
+ response = _call_image_inference_grpc_client(client, model_name, image_data)
+ else:
+ response = _call_image_inference_http_client(client, model_name, image_data)
+
+ return response
+
+
+def _call_image_inference_grpc_client(client, model_name: str, image_data):
+ if image_data.ndim == 3:
+ image_data = np.expand_dims(image_data, axis=0)
+ inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")]
+ inputs[0].set_data_from_numpy(image_data.astype(np.float32))
+
+ outputs = [grpcclient.InferRequestedOutput("output")]
+
+ try:
+ result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
+ return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")])
+ except Exception as e:
+ err_msg = f"Inference failed for model {model_name}: {str(e)}"
+ logger.error(err_msg)
+ raise RuntimeError(err_msg)
+
+
+def _call_image_inference_http_client(client, model_name: str, image_data):
+ base64_img = numpy_to_base64(image_data)
+
+ if model_name == "deplot":
+ payload = _prepare_deplot_payload(base64_img)
+ elif model_name in {"paddle", "cached", "yolox"}:
+ payload = _prepare_nim_payload(base64_img)
+ else:
+ raise ValueError(f"Model {model_name} is not supported.")
+
+ try:
+ url = client["endpoint_url"]
+ headers = client["headers"]
+
+ response = requests.post(url, json=payload, headers=headers)
+ response.raise_for_status() # Raise an exception for HTTP errors
- outputs = [grpcclient.InferRequestedOutput("output")]
+ # Parse the JSON response
+ json_response = response.json()
- try:
- result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs)
- return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")])
- except Exception as e:
- err_msg = f"Inference failed for model {model_name}: {str(e)}"
- logger.error(err_msg)
- raise RuntimeError(err_msg)
+ except requests.exceptions.RequestException as e:
+ raise RuntimeError(f"HTTP request failed: {e}")
+ except KeyError as e:
+ raise RuntimeError(f"Missing expected key in response: {e}")
+ except Exception as e:
+ raise RuntimeError(f"An error occurred during inference: {e}")
+ if model_name == "deplot":
+ result = _extract_content_from_deplot_response(json_response)
else:
- base64_img = numpy_to_base64(image_data)
-
- try:
- url = client["endpoint_url"]
- headers = client["headers"]
-
- messages = [
- {
- "role": "user",
- "content": f"Generate the underlying data table of the figure below: "
-                    f'<img src="data:image/png;base64,{base64_img}" />',
- }
- ]
- payload = {
- "model": model_name,
- "messages": messages,
- "max_tokens": 128,
- "stream": False,
- "temperature": 1.0,
- "top_p": 1.0,
- }
-
- response = requests.post(url, json=payload, headers=headers)
- response.raise_for_status() # Raise an exception for HTTP errors
-
- # Parse the JSON response
- json_response = response.json()
-
- # Validate the response structure
- if "choices" not in json_response or not json_response["choices"]:
- raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
-
- return json_response["choices"][0]["message"]["content"]
-
- except requests.exceptions.RequestException as e:
- raise RuntimeError(f"HTTP request failed: {e}")
- except KeyError as e:
- raise RuntimeError(f"Missing expected key in response: {e}")
- except Exception as e:
- raise RuntimeError(f"An error occurred during inference: {e}")
+ result = _extract_content_from_nim_response(json_response)
+
+ return result
+
+
+def _prepare_deplot_payload(
+ base64_img: str,
+ max_tokens: int = DEPLOT_MAX_TOKENS,
+ temperature: float = DEPLOT_TEMPERATURE,
+ top_p: float = DEPLOT_TOP_P,
+) -> Dict[str, Any]:
+ messages = [
+ {
+ "role": "user",
+ "content": f"Generate the underlying data table of the figure below: "
+            f'<img src="data:image/png;base64,{base64_img}" />',
+ }
+ ]
+ payload = {
+ "model": "google/deplot",
+ "messages": messages,
+ "max_tokens": max_tokens,
+ "stream": False,
+ "temperature": temperature,
+ "top_p": top_p,
+ }
+
+ return payload
+
+
+def _prepare_nim_payload(base64_img: str) -> Dict[str, Any]:
+ image_url = f"data:image/png;base64,{base64_img}"
+ image = {"type": "image_url", "image_url": {"url": image_url}}
+
+ message = {"content": [image]}
+ payload = {"messages": [message]}
+
+ return payload
+
+
+def _extract_content_from_deplot_response(json_response):
+ # Validate the response structure
+ if "choices" not in json_response or not json_response["choices"]:
+ raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
+
+ return json_response["choices"][0]["message"]["content"]
+
+
+def _extract_content_from_nim_response(json_response):
+ if "data" not in json_response or not json_response["data"]:
+ raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+
+ return json_response["data"][0]["content"]
# Perform inference and return predictions
@@ -172,6 +245,63 @@ def perform_model_inference(client, model_name: str, input_array: np.ndarray):
return query_response.as_numpy("output")
+def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] = None) -> np.ndarray:
+ """
+ Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding,
+ and transposing it into the required format.
+
+ This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC.
+ It is not necessary when using the HTTP endpoint.
+
+ Steps:
+ -----
+ 1. Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels.
+ 2. Normalizes the image using the `normalize_image` function.
+ 3. Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR.
+ 4. Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR.
+
+ Parameters:
+ ----------
+ array : np.ndarray
+ The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
+
+ Returns:
+ -------
+ np.ndarray
+ A preprocessed image with the shape (channels, height, width) and normalized pixel values.
+ The image will be padded to have dimensions that are multiples of 32, with the padding color set to 0.
+
+ Notes:
+ -----
+ - The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio.
+ - After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is
+ a requirement for PaddleOCR.
+    - Pixel values are scaled to [0, 1] and then standardized with the fixed per-channel means and
+      standard deviations before the image is padded and transposed.
+ """
+ if (not paddle_version) or (packaging.version.parse(paddle_version) < packaging.version.parse("0.2.0-rc1")):
+ return array
+
+ height, width = array.shape[:2]
+ scale_factor = 960 / max(height, width)
+ new_height = int(height * scale_factor)
+ new_width = int(width * scale_factor)
+ resized = cv2.resize(array, (new_width, new_height))
+
+ normalized = normalize_image(resized)
+
+ # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
+ new_height = (normalized.shape[0] + 31) // 32 * 32
+ new_width = (normalized.shape[1] + 31) // 32 * 32
+ padded, _ = pad_image(
+ normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32
+ )
+
+ # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
+ transposed = padded.transpose((2, 0, 1))
+
+ return transposed
+
+
def remove_url_endpoints(url) -> str:
"""Some configurations provide the full endpoint in the URL.
Ex: http://deplot:8000/v1/chat/completions. For hitting the
@@ -185,8 +315,8 @@ def remove_url_endpoints(url) -> str:
Returns:
str: URL with just the hostname:port portion remaining
"""
- if '/v1' in url:
- url = url.split('/v1')[0]
+ if "/v1" in url:
+ url = url.split("/v1")[0]
return url
@@ -204,26 +334,28 @@ def generate_url(url) -> str:
Returns:
str: Fully validated URL
"""
- if not re.match(r'^https?://', url):
+ if not re.match(r"^https?://", url):
# Add the default `http://` if its not already present in the URL
url = f"http://{url}"
- url = remove_url_endpoints(url)
-
return url
def is_ready(http_endpoint, ready_endpoint) -> bool:
-
# IF the url is empty or None that means the service was not configured
# and is therefore automatically marked as "ready"
- if http_endpoint is None or http_endpoint == '':
+ if http_endpoint is None or http_endpoint == "":
+ return True
+
+    # If the url is the build.nvidia.com hosted API (ai.api.nvidia.com), it is automatically assumed "ready"
+ if "ai.api.nvidia.com" in http_endpoint:
return True
url = generate_url(http_endpoint)
+ url = remove_url_endpoints(url)
- if not ready_endpoint.startswith('/') and not url.endswith('/'):
- ready_endpoint = '/' + ready_endpoint
+ if not ready_endpoint.startswith("/") and not url.endswith("/"):
+ ready_endpoint = "/" + ready_endpoint
url = url + ready_endpoint
@@ -258,3 +390,45 @@ def is_ready(http_endpoint, ready_endpoint) -> bool:
# Don't let anything squeeze by
logger.warning(f"Exception: {ex}")
return False
+
+
+@backoff.on_predicate(backoff.expo, max_value=5)
+def get_version(http_endpoint, metadata_endpoint="/v1/metadata", version_field="version") -> str:
+ if http_endpoint is None or http_endpoint == "":
+ return ""
+
+ url = generate_url(http_endpoint)
+ url = remove_url_endpoints(url)
+
+ if not metadata_endpoint.startswith("/") and not url.endswith("/"):
+ metadata_endpoint = "/" + metadata_endpoint
+
+ url = url + metadata_endpoint
+
+ # Call the metadata endpoint of the NIM
+ try:
+        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+ resp = requests.get(url, timeout=5)
+ if resp.status_code == 200:
+ return resp.json().get(version_field, "")
+ else:
+ # Any other code is confusing. We should log it with a warning
+ # as it could be something that might hold up ready state
+ logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.json()}")
+ return ""
+ except requests.HTTPError as http_err:
+ logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+ return ""
+ except requests.Timeout:
+ logger.warning(f"'{url}' request timed out")
+ return ""
+ except ConnectionError:
+ logger.warning(f"A connection error for '{url}' occurred")
+ return ""
+ except requests.RequestException as err:
+ logger.warning(f"An error occurred: {err} for '{url}'")
+ return ""
+ except Exception as ex:
+ # Don't let anything squeeze by
+ logger.warning(f"Exception: {ex}")
+ return ""
diff --git a/src/pipeline.py b/src/pipeline.py
index b07295fe..af27e203 100644
--- a/src/pipeline.py
+++ b/src/pipeline.py
@@ -81,13 +81,14 @@ def get_caption_classifier_service():
return triton_service_caption_classifier, triton_service_caption_classifier_name
-def get_yolox_service_table_detection():
+def get_table_detection_service(env_var_prefix):
+ prefix = env_var_prefix.upper()
grpc_endpoint = os.environ.get(
- "TABLE_DETECTION_GRPC_TRITON",
+ f"{prefix}_GRPC_ENDPOINT",
"",
)
http_endpoint = os.environ.get(
- "TABLE_DETECTION_HTTP_TRITON",
+ f"{prefix}_HTTP_ENDPOINT",
"",
)
auth_token = os.environ.get(
@@ -97,80 +98,16 @@ def get_yolox_service_table_detection():
"NGC_API_KEY",
"",
)
-
- logger.info(f"TABLE_DETECTION_GRPC_TRITON: {grpc_endpoint}")
- logger.info(f"TABLE_DETECTION_HTTP_TRITON: {http_endpoint}")
-
- return grpc_endpoint, http_endpoint, auth_token
-
-
-def get_paddle_service_table_detection():
- grpc_endpoint = os.environ.get(
- "PADDLE_GRPC_ENDPOINT",
- "",
- )
- http_endpoint = os.environ.get(
- "PADDLE_HTTP_ENDPOINT",
- "",
- )
- auth_token = os.environ.get(
- "NVIDIA_BUILD_API_KEY",
- "",
- ) or os.environ.get(
- "NGC_API_KEY",
- "",
- )
-
- logger.info(f"PADDLE_GRPC_ENDPOINT: {grpc_endpoint}")
- logger.info(f"PADDLE_HTTP_ENDPOINT: {http_endpoint}")
-
- return grpc_endpoint, http_endpoint, auth_token
-
-
-def get_deplot_service_table_detection():
- grpc_endpoint = os.environ.get(
- "DEPLOT_GRPC_ENDPOINT",
- "",
- )
- http_endpoint = os.environ.get(
- "DEPLOT_HTTP_ENDPOINT",
- "",
- )
- auth_token = os.environ.get(
- "NVIDIA_BUILD_API_KEY",
- "",
- ) or os.environ.get(
- "NGC_API_KEY",
- "",
- )
-
- logger.info(f"DEPLOT_GRPC_ENDPOINT: {grpc_endpoint}")
- logger.info(f"DEPLOT_HTTP_ENDPOINT: {http_endpoint}")
-
- return grpc_endpoint, http_endpoint, auth_token
-
-
-def get_cached_service_table_detection():
- grpc_endpoint = os.environ.get(
- "CACHED_GRPC_ENDPOINT",
- "",
- )
- http_endpoint = os.environ.get(
- "CACHED_HTTP_ENDPOINT",
- "",
- )
- auth_token = os.environ.get(
- "NVIDIA_BUILD_API_KEY",
- "",
- ) or os.environ.get(
- "NGC_API_KEY",
- "",
+ infer_protocol = os.environ.get(
+ f"{prefix}_INFER_PROTOCOL",
+ "http" if http_endpoint else "grpc" if grpc_endpoint else "",
)
- logger.info(f"CACHED_GRPC_ENDPOINT: {grpc_endpoint}")
- logger.info(f"CACHED_HTTP_ENDPOINT: {http_endpoint}")
+ logger.info(f"{prefix}_GRPC_TRITON: {grpc_endpoint}")
+ logger.info(f"{prefix}_HTTP_TRITON: {http_endpoint}")
+ logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}")
- return grpc_endpoint, http_endpoint, auth_token
+ return grpc_endpoint, http_endpoint, auth_token, infer_protocol
def get_default_cpu_count():
@@ -247,10 +184,10 @@ def add_metadata_injector_stage(pipe, morpheus_pipeline_config):
def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
- yolox_grpc, yolox_http, yolox_auth = get_yolox_service_table_detection()
- paddle_grpc, paddle_http, paddle_auth = get_paddle_service_table_detection()
- deplot_grpc, deplot_http, deplot_auth = get_deplot_service_table_detection()
- cached_grpc, cached_http, cached_auth = get_cached_service_table_detection()
+ yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox")
+ paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle")
+ deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot")
+ cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached")
pdf_content_extractor_config = ingest_config.get(
"pdf_content_extraction_module",
{
@@ -259,6 +196,10 @@ def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, defau
"deplot_endpoints": (deplot_grpc, deplot_http),
"paddle_endpoints": (paddle_grpc, paddle_http),
"yolox_endpoints": (yolox_grpc, yolox_http),
+ "cached_infer_protocol": cached_protocol,
+ "deplot_infer_protocol": deplot_protocol,
+ "paddle_infer_protocol": paddle_protocol,
+ "yolox_infer_protocol": yolox_protocol,
"auth_token": yolox_auth, # All auth tokens are the same for the moment
}
},
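
Reviewer note: the four copy-pasted getters collapse into one prefix-driven helper. Usage under the compose defaults shown at the top of this diff (the import path assumes src/ is on sys.path):

from pipeline import get_table_detection_service

# PADDLE_GRPC_ENDPOINT=paddle:8001, PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer,
# PADDLE_INFER_PROTOCOL=grpc
grpc_ep, http_ep, auth, protocol = get_table_detection_service("paddle")
# -> ("paddle:8001", "http://paddle:8000/v1/infer", "<NGC or build.nvidia.com key>", "grpc")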