Skip to content

Commit

Permalink
Move PaddleOCR preprocessing logic to nv-ingest (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv authored Oct 15, 2024
1 parent c707a2b commit a18057a
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 186 deletions.
23 changes: 7 additions & 16 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ services:
- "8000:8000"
- "8001:8001"
- "8002:8002"
volumes:
- ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
Expand All @@ -37,8 +35,6 @@ services:
- "8003:8000"
- "8004:8001"
- "8005:8002"
volumes:
- ${HOME}/.cache:/opt/nim/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
Expand All @@ -61,8 +57,6 @@ services:
- "8006:8000"
- "8007:8001"
- "8008:8002"
volumes:
- ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
Expand All @@ -85,8 +79,6 @@ services:
- "8009:8000"
- "8010:8001"
- "8011:8002"
volumes:
- ${HOME}/.cache:/home/nvs/.cache
user: root
environment:
- NIM_HTTP_API_PORT=8000
Expand Down Expand Up @@ -138,15 +130,16 @@ services:
- sys_nice
environment:
- CACHED_GRPC_ENDPOINT=cached:8001
- CACHED_HEALTH_ENDPOINT=cached:8000
- CACHED_HTTP_ENDPOINT=""
- CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer
- CACHED_INFER_PROTOCOL=grpc
- CUDA_VISIBLE_DEVICES=1
- DEPLOT_GRPC_ENDPOINT=""
# self hosted deplot
- DEPLOT_HEALTH_ENDPOINT=deplot:8000
- DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions
# build.nvidia.com hosted deplot
#- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot
- DEPLOT_INFER_PROTOCOL=http
- DOUGHNUT_GRPC_TRITON=triton-doughnut:8001
- INGEST_LOG_LEVEL=DEFAULT
- MESSAGE_CLIENT_HOST=redis
Expand All @@ -157,15 +150,13 @@ services:
- NVIDIA_BUILD_API_KEY=${NVIDIA_BUILD_API_KEY:-${NGC_API_KEY:-ngcapikey}}
- OTEL_EXPORTER_OTLP_ENDPOINT=otel-collector:4317
- PADDLE_GRPC_ENDPOINT=paddle:8001
- PADDLE_HEALTH_ENDPOINT=paddle:8000
- PADDLE_HTTP_ENDPOINT=""
- PADDLE_HTTP_ENDPOINT=http://paddle:8000/v1/infer
- PADDLE_INFER_PROTOCOL=grpc
- READY_CHECK_ALL_COMPONENTS=True
- REDIS_MORPHEUS_TASK_QUEUE=morpheus_task_queue
- TABLE_DETECTION_GRPC_TRITON=yolox:8001
- TABLE_DETECTION_HTTP_TRITON=""
- YOLOX_GRPC_ENDPOINT=yolox:8001
- YOLOX_HEALTH_ENDPOINT=yolox:8000
- YOLOX_HTTP_ENDPOINT=""
- YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer
- YOLOX_INFER_PROTOCOL=grpc
healthcheck:
test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1
interval: 10s
Expand Down
5 changes: 0 additions & 5 deletions helm/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,6 @@ redis:
## @param envVars.MINIO_PUBLIC_ADDRESS [default: "http://localhost:9000"] Override this to publicly routable minio address, default assumes port-forwarding
## @param envVars.MINIO_BUCKET [default: "nv-ingest"] Override this for specific minio bucket to upload extracted images to
## @skip envVars.REDIS_MORPHEUS_TASK_QUEUE
## @skip envVars.TABLE_DETECTION_GRPC_TRITON
## @skip envVars.TABLE_DETECTION_HTTP_TRITON
## @skip envVars.CACHED_GRPC_ENDPOINT
## @skip envVars.CACHED_HTTP_ENDPOINT
## @skip envVars.PADDLE_GRPC_ENDPOINT
Expand All @@ -271,9 +269,6 @@ envVars:
MINIO_PUBLIC_ADDRESS: http://localhost:9000
MINIO_BUCKET: nv-ingest

TABLE_DETECTION_GRPC_TRITON: nv-ingest-yolox:8001
TABLE_DETECTION_HTTP_TRITON: ""

CACHED_GRPC_ENDPOINT: nv-ingest-cached:8001
CACHED_HTTP_ENDPOINT: ""
PADDLE_GRPC_ENDPOINT: nv-ingest-paddle:8001
Expand Down
10 changes: 5 additions & 5 deletions src/nv_ingest/api/v1/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ async def get_ready_state() -> dict:
# for now to assume that if nv-ingest is running so is
# the pipeline.
morpheus_pipeline_ready = True

# We give the users an option to disable checking all distributed services for "readiness"
check_all_components = os.getenv("READY_CHECK_ALL_COMPONENTS", "True").lower()
if check_all_components in ['1', 'true', 'yes']:
yolox_ready = is_ready(os.getenv("YOLOX_HEALTH_ENDPOINT", None), "/v1/health/ready")
deplot_ready = is_ready(os.getenv("DEPLOT_HEALTH_ENDPOINT", None), "/v1/health/ready")
cached_ready = is_ready(os.getenv("CACHED_HEALTH_ENDPOINT", None), "/v1/health/ready")
paddle_ready = is_ready(os.getenv("PADDLE_HEALTH_ENDPOINT", None), "/v1/health/ready")
yolox_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready")
deplot_ready = is_ready(os.getenv("DEPLOT_HTTP_ENDPOINT", None), "/v1/health/ready")
cached_ready = is_ready(os.getenv("CACHED_HTTP_ENDPOINT", None), "/v1/health/ready")
paddle_ready = is_ready(os.getenv("PADDLE_HTTP_ENDPOINT", None), "/v1/health/ready")

if (ingest_ready
and morpheus_pipeline_ready
Expand Down
35 changes: 26 additions & 9 deletions src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import call_image_inference_model
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.nim.helpers import get_version
from nv_ingest.util.nim.helpers import perform_model_inference
from nv_ingest.util.nim.helpers import preprocess_image_for_paddle
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import ImageChart
from nv_ingest.util.pdf.metadata_aggregators import ImageTable
Expand All @@ -47,7 +49,6 @@
from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
from nv_ingest.util.pdf.pdfium import pdfium_try_get_bitmap_as_numpy


PADDLE_MIN_WIDTH = 32
PADDLE_MIN_HEIGHT = 32
YOLOX_MAX_BATCH_SIZE = 8
Expand Down Expand Up @@ -140,12 +141,19 @@ def extract_tables_and_charts_using_image_ensemble(

yolox_client = paddle_client = deplot_client = cached_client = None
try:
yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token)
yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token, config.yolox_infer_protocol)
if extract_tables:
paddle_client = create_inference_client(config.paddle_endpoints, config.auth_token)
paddle_client = create_inference_client(
config.paddle_endpoints, config.auth_token, config.paddle_infer_protocol
)
paddle_version = get_version(config.paddle_endpoints[1])
if extract_charts:
cached_client = create_inference_client(config.cached_endpoints, config.auth_token)
deplot_client = create_inference_client(config.deplot_endpoints, config.auth_token)
cached_client = create_inference_client(
config.cached_endpoints, config.auth_token, config.cached_infer_protocol
)
deplot_client = create_inference_client(
config.deplot_endpoints, config.auth_token, config.deplot_infer_protocol
)

batches = []
i = 0
Expand All @@ -164,7 +172,6 @@ def extract_tables_and_charts_using_image_ensemble(
input_array = prepare_images_for_inference(original_images)

output_array = perform_model_inference(yolox_client, "yolox", input_array, trace_info=trace_info)

results = process_inference_results(
output_array, original_image_shapes, num_classes, conf_thresh, iou_thresh, min_score, final_thresh
)
Expand All @@ -180,6 +187,7 @@ def extract_tables_and_charts_using_image_ensemble(
tables_and_charts,
extract_tables=extract_tables,
extract_charts=extract_charts,
paddle_version=paddle_version,
trace_info=trace_info,
)

Expand Down Expand Up @@ -313,6 +321,7 @@ def handle_table_chart_extraction(
tables_and_charts,
extract_tables=True,
extract_charts=True,
paddle_version=None,
trace_info=None,
):
"""
Expand Down Expand Up @@ -368,20 +377,28 @@ def handle_table_chart_extraction(
min_width=PADDLE_MIN_WIDTH,
min_height=PADDLE_MIN_HEIGHT,
)
if cropped is None:
continue

base64_img = numpy_to_base64(cropped)

if isinstance(paddle_client, grpcclient.InferenceServerClient):
cropped = preprocess_image_for_paddle(cropped, paddle_version=paddle_version)

table_content = call_image_inference_model(paddle_client, "paddle", cropped, trace_info=trace_info)
table_data = ImageTable(
content=table_content, image=base64_img, bbox=(w1, h1, w2, h2), max_width=width, max_height=height
)
tables_and_charts.append((page_idx, table_data))

elif extract_charts and label == "chart":
cropped = crop_image(original_image, (h1, w1, h2, w2))
if cropped is None:
continue

base64_img = numpy_to_base64(cropped)

deplot_result = call_image_inference_model(
deplot_client, "google/deplot", cropped, trace_info=trace_info
)
deplot_result = call_image_inference_model(deplot_client, "deplot", cropped, trace_info=trace_info)
cached_result = call_image_inference_model(cached_client, "cached", cropped, trace_info=trace_info)
chart_content = join_cached_and_deplot_output(cached_result, deplot_result)
chart_data = ImageChart(
Expand Down
15 changes: 14 additions & 1 deletion src/nv_ingest/schemas/pdf_extractor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class PDFiumConfigSchema(BaseModel):
paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)

cached_infer_protocol: str = ""
deplot_infer_protocol: str = ""
paddle_infer_protocol: str = ""
yolox_infer_protocol: str = ""

identify_nearby_objects: bool = False

@root_validator(pre=True)
Expand Down Expand Up @@ -93,7 +98,8 @@ def clean_service(service):
return None
return service

for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints", "yolox_endpoints"]:
for model_name in ["cached", "deplot", "paddle", "yolox"]:
endpoint_name = f"{model_name}_endpoints"
grpc_service, http_service = values.get(endpoint_name)
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)
Expand All @@ -103,6 +109,13 @@ def clean_service(service):

values[endpoint_name] = (grpc_service, http_service)

protocol_name = f"{model_name}_infer_protocol"
protocol_value = values.get(protocol_name)
if not protocol_value:
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
protocol_value = protocol_value.lower()
values[protocol_name] = protocol_value

return values

class Config:
Expand Down
8 changes: 7 additions & 1 deletion src/nv_ingest/util/image_processing/table_and_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def join_cached_and_deplot_output(cached_text, deplot_text):

if (cached_text is not None):
try:
cached_text_dict = json.loads(cached_text)
if isinstance(cached_text, str):
cached_text_dict = json.loads(cached_text)
elif isinstance(cached_text, dict):
cached_text_dict = cached_text
else:
cached_text_dict = {}

chart_content += cached_text_dict.get("chart_title", "")

if (deplot_text is not None):
Expand Down
68 changes: 66 additions & 2 deletions src/nv_ingest/util/image_processing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@


def pad_image(
array: np.ndarray, target_width: int = DEFAULT_MAX_WIDTH, target_height: int = DEFAULT_MAX_HEIGHT
array: np.ndarray,
target_width: int = DEFAULT_MAX_WIDTH,
target_height: int = DEFAULT_MAX_HEIGHT,
background_color: int = 255,
dtype=np.uint8,
) -> Tuple[np.ndarray, Tuple[int, int]]:
"""
Pads a NumPy array representing an image to the specified target dimensions.
Expand Down Expand Up @@ -68,7 +72,7 @@ def pad_image(
final_width = max(width, target_width)

# Create the canvas and place the original image on it
canvas = 255 * np.ones((final_height, final_width, array.shape[2]), dtype=np.uint8)
canvas = background_color * np.ones((final_height, final_width, array.shape[2]), dtype=dtype)
canvas[pad_height : pad_height + height, pad_width : pad_width + width] = array # noqa: E203

return canvas, (pad_width, pad_height)
Expand Down Expand Up @@ -113,6 +117,66 @@ def crop_image(
return cropped


def normalize_image(
array: np.ndarray,
r_mean: float = 0.485,
g_mean: float = 0.456,
b_mean: float = 0.406,
r_std: float = 0.229,
g_std: float = 0.224,
b_std: float = 0.225,
) -> np.ndarray:
"""
Normalizes an RGB image by applying a mean and standard deviation to each channel.
Parameters:
----------
array : np.ndarray
The input image array, which can be either grayscale or RGB. The image should have a shape of
(height, width, 3) for RGB images, or (height, width) or (height, width, 1) for grayscale images.
If a grayscale image is provided, it will be converted to RGB format by repeating the grayscale values
across all three channels (R, G, B).
r_mean : float, optional
The mean to be subtracted from the red channel (default is 0.485).
g_mean : float, optional
The mean to be subtracted from the green channel (default is 0.456).
b_mean : float, optional
The mean to be subtracted from the blue channel (default is 0.406).
r_std : float, optional
The standard deviation to divide the red channel by (default is 0.229).
g_std : float, optional
The standard deviation to divide the green channel by (default is 0.224).
b_std : float, optional
The standard deviation to divide the blue channel by (default is 0.225).
Returns:
-------
np.ndarray
A normalized image array with the same shape as the input, where the RGB channels have been normalized
by the given means and standard deviations.
Notes:
-----
The input pixel values should be in the range [0, 255], and the function scales these values to [0, 1]
before applying normalization.
If the input image is grayscale, it is converted to an RGB image by duplicating the grayscale values
across the three color channels.
"""
# If the input is a grayscale image with shape (height, width) or (height, width, 1),
# convert it to RGB with shape (height, width, 3).
if array.ndim == 2 or array.shape[2] == 1:
array = np.dstack((array, 255 * np.ones_like(array), 255 * np.ones_like(array)))

height, width = array.shape[:2]

mean = np.array([r_mean, g_mean, b_mean]).reshape((1, 1, 3)).astype(np.float32)
std = np.array([r_std, g_std, b_std]).reshape((1, 1, 3)).astype(np.float32)
output_array = (array.astype("float32") / 255.0 - mean) / std

return output_array


def numpy_to_base64(array: np.ndarray) -> str:
"""
Converts a NumPy array representing an image to a base64-encoded string.
Expand Down
Loading

0 comments on commit a18057a

Please sign in to comment.