From adf3aae011c080a339494558d8a5442209097563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=98=E5=BF=A7=E5=8C=97=E8=90=B1=E8=8D=89?= Date: Sat, 23 Nov 2024 12:16:20 +0800 Subject: [PATCH 1/3] feat: onnx support --- pdf2zh/doclayout.py | 210 ++++++++++++++++++++++++++++++++++++++++++++ pdf2zh/pdf2zh.py | 25 +++--- 2 files changed, 220 insertions(+), 15 deletions(-) create mode 100644 pdf2zh/doclayout.py diff --git a/pdf2zh/doclayout.py b/pdf2zh/doclayout.py new file mode 100644 index 00000000..e4d7e3cb --- /dev/null +++ b/pdf2zh/doclayout.py @@ -0,0 +1,210 @@ +import abc +import cv2 +import numpy as np +import contextlib +from huggingface_hub import hf_hub_download + + +class DocLayoutModel(abc.ABC): + @staticmethod + def load_torch(): + model = TorchModel.from_pretrained( + repo_id="juliozhao/DocLayout-YOLO-DocStructBench", + filename="doclayout_yolo_docstructbench_imgsz1024.pt", + ) + return model + + @staticmethod + def load_onnx(): + model = OnnxModel.from_pretrained( + repo_id="wybxc/DocLayout-YOLO-DocStructBench-onnx", + filename="doclayout_yolo_docstructbench_imgsz1024.onnx", + ) + return model + + @staticmethod + def load_available(): + with contextlib.suppress(ImportError): + return DocLayoutModel.load_torch() + + with contextlib.suppress(ImportError): + return DocLayoutModel.load_onnx() + + raise ImportError( + "Please install the `torch` or `onnx` feature to use the DocLayout model." + ) + + @property + @abc.abstractmethod + def stride(self) -> int: + """Stride of the model input.""" + pass + + @abc.abstractmethod + def predict(self, image, imgsz=1024, **kwargs) -> list: + """ + Predict the layout of a document page. + + Args: + image: The image of the document page. + imgsz: Resize the image to this size. Must be a multiple of the stride. + **kwargs: Additional arguments. + """ + pass + + +class TorchModel(DocLayoutModel): + def __init__(self, model_path: str): + try: + import doclayout_yolo + except ImportError: + raise ImportError( + "Please install the `torch` feature to use the Torch model." + ) + + self.model_path = model_path + self.model = doclayout_yolo.YOLOv10(model_path) + + @staticmethod + def from_pretrained(repo_id: str, filename: str): + pth = hf_hub_download(repo_id=repo_id, filename=filename) + return TorchModel(pth) + + @property + def stride(self): + return 32 + + def predict(self, *args, **kwargs): + return self.model.predict(*args, **kwargs) + + +class YoloResult: + """Helper class to store detection results from ONNX model.""" + + def __init__(self, boxes, names): + self.boxes = [YoloBox(data=d) for d in boxes] + self.boxes.sort(key=lambda x: x.conf, reverse=True) + self.names = names + + +class YoloBox: + """Helper class to store detection results from ONNX model.""" + + def __init__(self, data): + self.xyxy = data[:4] + self.conf = data[-2] + self.cls = data[-1] + + +class OnnxModel(DocLayoutModel): + def __init__(self, model_path: str): + import ast + + try: + + import onnx + import onnxruntime + except ImportError: + raise ImportError( + "Please install the `onnx` feature to use the ONNX model." + ) + + self.model_path = model_path + + model = onnx.load(model_path) + metadata = {d.key: d.value for d in model.metadata_props} + self._stride = ast.literal_eval(metadata["stride"]) + self._names = ast.literal_eval(metadata["names"]) + + self.model = onnxruntime.InferenceSession(model.SerializeToString()) + + @staticmethod + def from_pretrained(repo_id: str, filename: str): + pth = hf_hub_download(repo_id=repo_id, filename=filename) + return OnnxModel(pth) + + @property + def stride(self): + return self._stride + + def resize_and_pad_image(self, image, new_shape): + """ + Resize and pad the image to the specified size, ensuring dimensions are multiples of stride. + + Parameters: + - image: Input image + - new_shape: Target size (integer or (height, width) tuple) + - stride: Padding alignment stride, default 32 + + Returns: + - Processed image + """ + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + h, w = image.shape[:2] + new_h, new_w = new_shape + + # Calculate scaling ratio + r = min(new_h / h, new_w / w) + resized_h, resized_w = int(round(h * r)), int(round(w * r)) + + # Resize image + image = cv2.resize( + image, (resized_w, resized_h), interpolation=cv2.INTER_LINEAR + ) + + # Calculate padding size and align to stride multiple + pad_w = (new_w - resized_w) % self.stride + pad_h = (new_h - resized_h) % self.stride + top, bottom = pad_h // 2, pad_h - pad_h // 2 + left, right = pad_w // 2, pad_w - pad_w // 2 + + # Add padding + image = cv2.copyMakeBorder( + image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) + + return image + + def scale_boxes(self, img1_shape, boxes, img0_shape): + """ + Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally + specified in (img1_shape) to the shape of a different image (img0_shape). + + Args: + img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) + img0_shape (tuple): the shape of the target image, in the format of (height, width). + + Returns: + boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) + """ + + # Calculate scaling ratio + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) + + # Calculate padding size + pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1) + pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1) + + # Remove padding and scale boxes + boxes[..., :4] = (boxes[..., :4] - [pad_x, pad_y, pad_x, pad_y]) / gain + return boxes + + def predict(self, image, imgsz=1024, **kwargs): + # Preprocess input image + orig_h, orig_w = image.shape[:2] + pix = self.resize_and_pad_image(image, new_shape=imgsz) + pix = np.transpose(pix, (2, 0, 1)) # CHW + pix = np.expand_dims(pix, axis=0) # BCHW + pix = pix.astype(np.float32) / 255.0 # Normalize to [0, 1] + new_h, new_w = pix.shape[2:] + + # Run inference + preds = self.model.run(None, {"images": pix})[0] + + # Postprocess predictions + preds = preds[preds[..., 4] > 0.25] + preds[..., :4] = self.scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w)) + return [YoloResult(boxes=preds, names=self._names)] diff --git a/pdf2zh/pdf2zh.py b/pdf2zh/pdf2zh.py index 2cfbcd9e..14e1ae05 100644 --- a/pdf2zh/pdf2zh.py +++ b/pdf2zh/pdf2zh.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Container, Iterable, List, Optional import pymupdf -from huggingface_hub import hf_hub_download +from pathlib import Path from pdf2zh import __version__ from pdf2zh.pdfexceptions import PDFValueError @@ -27,10 +27,14 @@ def setup_log() -> None: - import doclayout_yolo - logging.basicConfig() - doclayout_yolo.utils.LOGGER.setLevel(logging.WARNING) + + try: + import doclayout_yolo + + doclayout_yolo.utils.LOGGER.setLevel(logging.WARNING) + except ImportError: + pass def check_files(files: List[str]) -> List[str]: @@ -73,8 +77,7 @@ def extract_text( output: str = "", **kwargs: Any, ) -> AnyIO: - import doclayout_yolo - + from pdf2zh.doclayout import DocLayoutModel import pdf2zh.high_level if not files: @@ -86,15 +89,7 @@ def extract_text( output_type = alttype outfp: AnyIO = sys.stdout - # pth = os.path.join(tempfile.gettempdir(), 'doclayout_yolo_docstructbench_imgsz1024.pt') - # if not os.path.exists(pth): - # print('Downloading...') - # urllib.request.urlretrieve("http://huggingface.co/juliozhao/DocLayout-YOLO-DocStructBench/resolve/main/doclayout_yolo_docstructbench_imgsz1024.pt",pth) - pth = hf_hub_download( - repo_id="juliozhao/DocLayout-YOLO-DocStructBench", - filename="doclayout_yolo_docstructbench_imgsz1024.pt", - ) - model = doclayout_yolo.YOLOv10(pth) + model = DocLayoutModel.load_available() for file in files: filename = os.path.splitext(os.path.basename(file))[0] From 679a4b25dde40cf3442b47e6e2e0f7337efe76a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=98=E5=BF=A7=E5=8C=97=E8=90=B1=E8=8D=89?= Date: Sun, 24 Nov 2024 01:10:04 +0800 Subject: [PATCH 2/3] feat!: make torch an optional dependency --- pdf2zh/doclayout.py | 4 +++- pdf2zh/high_level.py | 9 ++------- pdf2zh/utils.py | 13 +++++++++++++ pyproject.toml | 11 ++++++++--- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/pdf2zh/doclayout.py b/pdf2zh/doclayout.py index e4d7e3cb..cf4b62f0 100644 --- a/pdf2zh/doclayout.py +++ b/pdf2zh/doclayout.py @@ -206,5 +206,7 @@ def predict(self, image, imgsz=1024, **kwargs): # Postprocess predictions preds = preds[preds[..., 4] > 0.25] - preds[..., :4] = self.scale_boxes((new_h, new_w), preds[..., :4], (orig_h, orig_w)) + preds[..., :4] = self.scale_boxes( + (new_h, new_w), preds[..., :4], (orig_h, orig_w) + ) return [YoloResult(boxes=preds, names=self._names)] diff --git a/pdf2zh/high_level.py b/pdf2zh/high_level.py index 940d5df9..fbe51d6e 100644 --- a/pdf2zh/high_level.py +++ b/pdf2zh/high_level.py @@ -4,7 +4,6 @@ import sys from io import StringIO from typing import Any, BinaryIO, Container, Iterator, Optional, cast -import torch import numpy as np import tqdm from pymupdf import Document @@ -22,7 +21,7 @@ from pdf2zh.pdfexceptions import PDFValueError from pdf2zh.pdfinterp import PDFPageInterpreter, PDFResourceManager from pdf2zh.pdfpage import PDFPage -from pdf2zh.utils import AnyIO, FileOrName, open_filename +from pdf2zh.utils import AnyIO, FileOrName, open_filename, get_device def extract_text_to_fp( @@ -176,11 +175,7 @@ def extract_text_to_fp( pix.height, pix.width, 3 )[:, :, ::-1] page_layout = model.predict( - image, - imgsz=int(pix.height / 32) * 32, - device=( - "cuda:0" if torch.cuda.is_available() else "cpu" - ), # Auto-select GPU if available + image, imgsz=int(pix.height / 32) * 32, device=get_device() )[0] # kdtree 是不可能 kdtree 的,不如直接渲染成图片,用空间换时间 box = np.ones((pix.height, pix.width)) diff --git a/pdf2zh/utils.py b/pdf2zh/utils.py index 25697fdf..ad5643b8 100644 --- a/pdf2zh/utils.py +++ b/pdf2zh/utils.py @@ -819,3 +819,16 @@ def format_int_alpha(value: int) -> str: result.reverse() return "".join(result) + + +def get_device(): + """Get the device to use for computation.""" + try: + import torch + + if torch.cuda.is_available(): + return "cuda:0" + except ImportError: + pass + + return "cpu" diff --git a/pyproject.toml b/pyproject.toml index 2b3e5804..afede3c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Latex PDF Translator" authors = [{ name = "Byaidu", email = "byaidux@gmail.com" }] license = "AGPL-3.0" readme = "README.md" -requires-python = ">=3.8,<3.13" +requires-python = ">=3.9,<3.13" classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", @@ -17,7 +17,6 @@ dependencies = [ "pymupdf", "tqdm", "tenacity", - "doclayout-yolo", "numpy", "ollama", "deepl<1.19.1", @@ -25,10 +24,16 @@ dependencies = [ "azure-ai-translation-text<=1.0.1", "gradio", "huggingface_hub", - "torch", + "onnx", + "onnxruntime", + "opencv-python-headless", ] [project.optional-dependencies] +torch = [ + "doclayout-yolo", + "torch", +] dev = [ "black", "flake8", From d5eed6c33d6f5e9c959ec30e60c76af870f1fe32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=98=E5=BF=A7=E5=8C=97=E8=90=B1=E8=8D=89?= Date: Sun, 24 Nov 2024 01:17:43 +0800 Subject: [PATCH 3/3] chore: fix lint error --- pdf2zh/doclayout.py | 3 ++- pdf2zh/pdf2zh.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pdf2zh/doclayout.py b/pdf2zh/doclayout.py index cf4b62f0..1df95bcb 100644 --- a/pdf2zh/doclayout.py +++ b/pdf2zh/doclayout.py @@ -173,7 +173,8 @@ def scale_boxes(self, img1_shape, boxes, img0_shape): specified in (img1_shape) to the shape of a different image (img0_shape). Args: - img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + img1_shape (tuple): The shape of the image that the bounding boxes are for, + in the format of (height, width). boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) img0_shape (tuple): the shape of the target image, in the format of (height, width). diff --git a/pdf2zh/pdf2zh.py b/pdf2zh/pdf2zh.py index 14e1ae05..23b38809 100644 --- a/pdf2zh/pdf2zh.py +++ b/pdf2zh/pdf2zh.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Any, Container, Iterable, List, Optional import pymupdf -from pathlib import Path from pdf2zh import __version__ from pdf2zh.pdfexceptions import PDFValueError