From 442171a42de5e4833b61c89b0ace4a0a83a2e2d1 Mon Sep 17 00:00:00 2001 From: Joker1212 <519548295@qq.com> Date: Thu, 19 Sep 2024 18:02:36 +0800 Subject: [PATCH 1/4] feat: add text oritation small angle fix --- wired_table_rec/table_line_rec.py | 4 +- wired_table_rec/table_line_rec_plus.py | 2 +- wired_table_rec/table_recover.py | 62 +++++++++++++------------- wired_table_rec/utils.py | 45 +++++++++++++++++++ 4 files changed, 80 insertions(+), 33 deletions(-) diff --git a/wired_table_rec/table_line_rec.py b/wired_table_rec/table_line_rec.py index 7be1004..2447066 100644 --- a/wired_table_rec/table_line_rec.py +++ b/wired_table_rec/table_line_rec.py @@ -48,7 +48,9 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: [box_4_2_poly_to_box_4_1(box) for box in polygons] ) polygons = np.delete(polygons, list(del_idxs), axis=0) - _, idx = sorted_ocr_boxes([box_4_2_poly_to_box_4_1(box) for box in polygons]) + _, idx = sorted_ocr_boxes( + [box_4_2_poly_to_box_4_1(box) for box in polygons], threhold=0.4 + ) polygons = polygons[idx] polygons = merge_adjacent_polys(polygons) return polygons diff --git a/wired_table_rec/table_line_rec_plus.py b/wired_table_rec/table_line_rec_plus.py index 8880891..f4e2c77 100644 --- a/wired_table_rec/table_line_rec_plus.py +++ b/wired_table_rec/table_line_rec_plus.py @@ -44,7 +44,7 @@ def __call__(self, img: np.ndarray) -> Optional[np.ndarray]: polygons[:, 3, :].copy(), ) _, idx = sorted_ocr_boxes( - [box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons] + [box_4_2_poly_to_box_4_1(poly_box) for poly_box in polygons], threhold=0.4 ) polygons = polygons[idx] return polygons diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index be0502c..16e409c 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -114,38 +114,38 @@ def get_benchmark_rows( ) -> Tuple[np.ndarray, List[float], int]: leftmost_cell_idxs = [v[0] for v in rows.values()] benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1] - - theta = 15 + # 有线表格模型精度足够,不进行工程化修正,避免更多未知问题 + # theta = 15 # 遍历其他所有的框,按照y轴进行区间划分 - range_res = {} - for cur_idx, cur_box in enumerate(polygons): - # fix cur_idx in benchmark_x - if cur_idx in leftmost_cell_idxs: - continue - - cur_y = cur_box[0, 1] - - start_idx, end_idx = None, None - for i, v in enumerate(benchmark_x): - if cur_y - theta <= v <= cur_y + theta: - break - - if cur_y > v: - start_idx = i - continue - - if cur_y < v: - end_idx = i - break - - range_res[cur_idx] = [start_idx, end_idx] - - sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True)) - for k, v in sorted_res.items(): - if not all(v): - continue - - benchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1]) + # range_res = {} + # for cur_idx, cur_box in enumerate(polygons): + # # fix cur_idx in benchmark_x + # if cur_idx in leftmost_cell_idxs: + # continue + # + # cur_y = cur_box[0, 1] + # + # start_idx, end_idx = None, None + # for i, v in enumerate(benchmark_x): + # if cur_y - theta <= v <= cur_y + theta: + # break + # + # if cur_y > v: + # start_idx = i + # continue + # + # if cur_y < v: + # end_idx = i + # break + # + # range_res[cur_idx] = [start_idx, end_idx] + # + # sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True)) + # for k, v in sorted_res.items(): + # if not all(v): + # continue + # + # benchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1]) each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist() diff --git a/wired_table_rec/utils.py b/wired_table_rec/utils.py index a721751..d69676c 100644 --- a/wired_table_rec/utils.py +++ b/wired_table_rec/utils.py @@ -1,4 +1,5 @@ # -*- encoding: utf-8 -*- +import math import traceback from io import BytesIO from pathlib import Path @@ -350,3 +351,47 @@ def _scale_size(size, scale): scale = (scale, scale) w, h = size return int(w * float(scale[0]) + 0.5), int(h * float(scale[1]) + 0.5) + + +class ImageOrientationCorrector: + """ + 对图片小角度(-90 - + 90度进行修正) + """ + + def __init__(self): + self.img_loader = LoadImage() + + def __call__(self, img: InputType): + img = self.img_loader(img) + # 取灰度 + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # 二值化 + gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1] + # 边缘检测 + edges = cv2.Canny(gray, 100, 250, apertureSize=3) + # 霍夫变换,摘自https://blog.csdn.net/feilong_csdn/article/details/81586322 + lines = cv2.HoughLines(edges, 1, np.pi / 180, 0) + for rho, theta in lines[0]: + a = np.cos(theta) + b = np.sin(theta) + x0 = a * rho + y0 = b * rho + x1 = int(x0 + 1000 * (-b)) + y1 = int(y0 + 1000 * (a)) + x2 = int(x0 - 1000 * (-b)) + y2 = int(y0 - 1000 * (a)) + if x1 == x2 or y1 == y2: + return img + else: + t = float(y2 - y1) / (x2 - x1) + # 得到角度后 + rotate_angle = math.degrees(math.atan(t)) + if rotate_angle > 45: + rotate_angle = -90 + rotate_angle + elif rotate_angle < -45: + rotate_angle = 90 + rotate_angle + # 旋转图像 + (h, w) = img.shape[:2] + center = (w // 2, h // 2) + M = cv2.getRotationMatrix2D(center, rotate_angle, 1.0) + return cv2.warpAffine(img, M, (w, h)) From 6b76462adfe6fb8c8d330b6408c6c11636e7c1b8 Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Thu, 19 Sep 2024 21:39:30 +0800 Subject: [PATCH 2/4] fix: adaptive for py 3.8 --- README.md | 13 +++++++-- lineless_table_rec/main.py | 4 +-- lineless_table_rec/utils_table_recover.py | 32 ++++++++++----------- setup_table_cls.py | 32 ++++++++++----------- tests/test_wired_table_rec.py | 2 +- wired_table_rec/main.py | 10 +++---- wired_table_rec/utils_table_recover.py | 35 ++++++++++++----------- 7 files changed, 70 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 0afd3eb..5e6ae93 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,16 @@ print(f"elasp: {elasp}") # # 可视化 ocr 识别框 # plot_rec_box(img_path, f"{output_dir}/ocr_box.jpg", ocr_res) ``` - +#### 偏移修正 +```python +import cv2 +img_path = f'tests/test_files/wired/squeeze_error.jpeg' +from wired_table_rec.utils import ImageOrientationCorrector +img_orientation_corrector = ImageOrientationCorrector() +img = cv2.imread(img_path) +img = img_orientation_corrector(img) +cv2.imwrite(f'img_rotated.jpg', img) +``` ## FAQ (Frequently Asked Questions) 1. **问:偏移的图片能够处理吗?** @@ -101,7 +110,7 @@ print(f"elasp: {elasp}") ### TODO List -- [ ] 识别前图片偏移修正 +- [ ] 识别前图片偏移修正(完成有线表格小角度偏移修正) - [ ] 增加数据集数量,增加更多评测对比 - [ ] 优化无线表格模型 diff --git a/lineless_table_rec/main.py b/lineless_table_rec/main.py index 62a4e6b..cf0116b 100644 --- a/lineless_table_rec/main.py +++ b/lineless_table_rec/main.py @@ -119,10 +119,10 @@ def __call__( def transform_res( self, - cell_box_det_map: dict[int, List[any]], + cell_box_det_map: Dict[int, List[any]], polygons: np.ndarray, logi_points: List[np.ndarray], - ) -> List[dict[str, any]]: + ) -> List[Dict[str, any]]: res = [] for i in range(len(polygons)): ocr_res_list = cell_box_det_map.get(i) diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index 670364b..c99b02d 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com import os import random -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Set, Tuple import cv2 import numpy as np @@ -67,7 +67,7 @@ def compute_poly_iou(a: np.ndarray, b: np.ndarray) -> float: return float(inter_area) / union_area -def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: +def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]: """ :param table_boxes: [[xmin,ymin,xmax,ymax]] :return: @@ -95,7 +95,9 @@ def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: return delete_idx -def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: +def calculate_iou( + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list] +) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -127,7 +129,7 @@ def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: def caculate_single_axis_iou( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x" + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x" ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -151,8 +153,8 @@ def caculate_single_axis_iou( def is_box_contained( - box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2 -) -> int | None: + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], threshold=0.2 +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -195,8 +197,8 @@ def is_box_contained( def is_single_axis_contained( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold=0.2 -) -> int | None: + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x", threhold=0.2 +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -228,8 +230,8 @@ def is_single_axis_contained( def sorted_ocr_boxes( - dt_boxes: np.ndarray | list, threhold: float = 0.2 -) -> tuple[np.ndarray | list, list[int]]: + dt_boxes: Union[np.ndarray, List], threhold: float = 0.2 +) -> Tuple[Union[np.ndarray, list], List[int]]: """ Sort text boxes in order from top to bottom, left to right args: @@ -266,9 +268,7 @@ def sorted_ocr_boxes( return _boxes, indices -def gather_ocr_list_by_row( - ocr_list: list[list[list[float], str]], thehold: float = 0.2 -) -> list[list[list[float], str]]: +def gather_ocr_list_by_row(ocr_list: List[Any], thehold: float = 0.2) -> List[Any]: """ :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] :return: @@ -305,12 +305,12 @@ def gather_ocr_list_by_row( return ocr_list -def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: +def box_4_1_poly_to_box_4_2(poly_box: Union[np.ndarray, list]) -> List[List[float]]: xmin, ymin, xmax, ymax = tuple(poly_box) return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] -def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: +def box_4_2_poly_to_box_4_1(poly_box: Union[np.ndarray, list]) -> List[float]: """ 将poly_box转换为box_4_1 :param poly_box: @@ -407,7 +407,7 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd def plot_html_table( - logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]] + logi_points: Union[np.ndarray, list], cell_box_map: Dict[int, List[str]] ) -> str: # 初始化最大行数和列数 max_row = 0 diff --git a/setup_table_cls.py b/setup_table_cls.py index 70df236..7a8e87d 100644 --- a/setup_table_cls.py +++ b/setup_table_cls.py @@ -1,12 +1,13 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import sys from pathlib import Path from typing import List, Union import setuptools -# from get_pypi_latest_version import GetPyPiLatestVersion +from get_pypi_latest_version import GetPyPiLatestVersion def read_txt(txt_path: Union[Path, str]) -> List[str]: @@ -17,21 +18,20 @@ def read_txt(txt_path: Union[Path, str]) -> List[str]: MODULE_NAME = "table_cls" -# obtainer = GetPyPiLatestVersion() -# try: -# latest_version = obtainer(MODULE_NAME) -# except Exception: -# latest_version = "0.0.0" -# -# VERSION_NUM = obtainer.version_add_one(latest_version) -VERSION_NUM = "1.0.0" - -# if len(sys.argv) > 2: -# match_str = " ".join(sys.argv[2:]) -# matched_versions = obtainer.extract_version(match_str) -# if matched_versions: -# VERSION_NUM = matched_versions -# sys.argv = sys.argv[:2] +obtainer = GetPyPiLatestVersion() +try: + latest_version = obtainer(MODULE_NAME) +except Exception: + latest_version = "0.0.0" + +VERSION_NUM = obtainer.version_add_one(latest_version) + +if len(sys.argv) > 2: + match_str = " ".join(sys.argv[2:]) + matched_versions = obtainer.extract_version(match_str) + if matched_versions: + VERSION_NUM = matched_versions +sys.argv = sys.argv[:2] setuptools.setup( name=MODULE_NAME, diff --git a/tests/test_wired_table_rec.py b/tests/test_wired_table_rec.py index 37d46e0..b4f263a 100644 --- a/tests/test_wired_table_rec.py +++ b/tests/test_wired_table_rec.py @@ -41,7 +41,7 @@ def test_squeeze_bug(): ocr_result, _ = ocr_engine(img_path) table_str, *_ = table_recog(str(img_path), ocr_result) td_nums = get_td_nums(table_str) - assert td_nums == 228 + assert td_nums == 291 @pytest.mark.parametrize( diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index 0b0b784..c220f97 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -7,7 +7,7 @@ import time import traceback from pathlib import Path -from typing import List, Optional, Tuple, Union, Dict +from typing import List, Optional, Tuple, Union, Dict, Any import numpy as np import cv2 @@ -50,7 +50,7 @@ def __call__( self, img: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, - ) -> Tuple[str, float, list]: + ) -> Tuple[str, float, Any, Any, Any]: if self.ocr is None and ocr_result is None: raise ValueError( "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." @@ -109,10 +109,10 @@ def __call__( def transform_res( self, - cell_box_det_map: dict[int, List[any]], + cell_box_det_map: Dict[int, List[any]], polygons: np.ndarray, logi_points: List[np.ndarray], - ) -> List[dict[str, any]]: + ) -> List[Dict[str, any]]: res = [] for i in range(len(polygons)): ocr_res_list = cell_box_det_map.get(i) @@ -152,7 +152,7 @@ def re_rec( img: np.ndarray, sorted_polygons: np.ndarray, cell_box_map: Dict[int, List[str]], - ) -> Dict[int, List[any]]: + ) -> Dict[int, List[Any]]: """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" # for i in range(sorted_polygons.shape[0]): diff --git a/wired_table_rec/utils_table_recover.py b/wired_table_rec/utils_table_recover.py index 4af788c..447e9cc 100644 --- a/wired_table_rec/utils_table_recover.py +++ b/wired_table_rec/utils_table_recover.py @@ -3,7 +3,7 @@ # @Contact: liekkaskono@163.com import os import random -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, Set, Tuple import cv2 import numpy as np @@ -36,7 +36,9 @@ def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray: return np.array(_boxes) -def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: +def calculate_iou( + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list] +) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -68,7 +70,7 @@ def calculate_iou(box1: list | np.ndarray, box2: list | np.ndarray) -> float: def caculate_single_axis_iou( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x" + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x" ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -92,8 +94,8 @@ def caculate_single_axis_iou( def is_box_contained( - box1: list | np.ndarray, box2: list | np.ndarray, threshold=0.2 -) -> int | None: + box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], threshold=0.2 +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -136,8 +138,11 @@ def is_box_contained( def is_single_axis_contained( - box1: list | np.ndarray, box2: list | np.ndarray, axis="x", threhold: float = 0.2 -) -> int | None: + box1: Union[np.ndarray, list], + box2: Union[np.ndarray, list], + axis="x", + threhold: float = 0.2, +) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] :param box2: Iterable [xmin,ymin,xmax,ymax] @@ -168,7 +173,7 @@ def is_single_axis_contained( return None -def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: +def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]: """ :param table_boxes: [[xmin,ymin,xmax,ymax]] :return: @@ -197,8 +202,8 @@ def filter_duplicated_box(table_boxes: list[list[float]]) -> set[int]: def sorted_ocr_boxes( - dt_boxes: np.ndarray | list, threhold: float = 0.2 -) -> tuple[np.ndarray | list, list[int]]: + dt_boxes: Union[np.ndarray, list], threhold: float = 0.2 +) -> Tuple[Union[np.ndarray, list], List[int]]: """ Sort text boxes in order from top to bottom, left to right args: @@ -312,12 +317,12 @@ def plot_rec_box(img_path, output_path, sorted_polygons): cv2.imwrite(output_path, img) -def box_4_1_poly_to_box_4_2(poly_box: list | np.ndarray) -> list[list[float]]: +def box_4_1_poly_to_box_4_2(poly_box: Union[list, np.ndarray]) -> List[List[float]]: xmin, ymin, xmax, ymax = tuple(poly_box) return [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]] -def box_4_2_poly_to_box_4_1(poly_box: list | np.ndarray) -> list[float]: +def box_4_2_poly_to_box_4_1(poly_box: Union[list, np.ndarray]) -> List[Any]: """ 将poly_box转换为box_4_1 :param poly_box: @@ -357,9 +362,7 @@ def match_ocr_cell(dt_rec_boxes: List[List[Union[Any, str]]], pred_bboxes: np.nd return matched, not_match_orc_boxes -def gather_ocr_list_by_row( - ocr_list: list[list[list[float], str]], threhold: float = 0.2 -) -> list[list[list[float], str]]: +def gather_ocr_list_by_row(ocr_list: List[Any], threhold: float = 0.2) -> List[Any]: """ :param ocr_list: [[[xmin,ymin,xmax,ymax], text]] :return: @@ -555,7 +558,7 @@ def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray): def plot_html_table( - logi_points: np.ndarray | list, cell_box_map: Dict[int, List[str]] + logi_points: Union[Union[np.ndarray, list]], cell_box_map: Dict[int, List[str]] ) -> str: # 初始化最大行数和列数 max_row = 0 From a7fc1da04c81fcfef7f0f7706ffb20fab6cac3fe Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Thu, 19 Sep 2024 22:43:05 +0800 Subject: [PATCH 3/4] fix: remove useless code --- wired_table_rec/table_recover.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/wired_table_rec/table_recover.py b/wired_table_rec/table_recover.py index 16e409c..afb2c2d 100644 --- a/wired_table_rec/table_recover.py +++ b/wired_table_rec/table_recover.py @@ -114,38 +114,6 @@ def get_benchmark_rows( ) -> Tuple[np.ndarray, List[float], int]: leftmost_cell_idxs = [v[0] for v in rows.values()] benchmark_x = polygons[leftmost_cell_idxs][:, 0, 1] - # 有线表格模型精度足够,不进行工程化修正,避免更多未知问题 - # theta = 15 - # 遍历其他所有的框,按照y轴进行区间划分 - # range_res = {} - # for cur_idx, cur_box in enumerate(polygons): - # # fix cur_idx in benchmark_x - # if cur_idx in leftmost_cell_idxs: - # continue - # - # cur_y = cur_box[0, 1] - # - # start_idx, end_idx = None, None - # for i, v in enumerate(benchmark_x): - # if cur_y - theta <= v <= cur_y + theta: - # break - # - # if cur_y > v: - # start_idx = i - # continue - # - # if cur_y < v: - # end_idx = i - # break - # - # range_res[cur_idx] = [start_idx, end_idx] - # - # sorted_res = dict(sorted(range_res.items(), key=lambda x: x[0], reverse=True)) - # for k, v in sorted_res.items(): - # if not all(v): - # continue - # - # benchmark_x = np.insert(benchmark_x, v[1], polygons[k][0, 1]) each_row_widths = (benchmark_x[1:] - benchmark_x[:-1]).tolist() From dcbc8c21de4ec5bd623887d87b4a60cd1fa290fc Mon Sep 17 00:00:00 2001 From: Jokcer <519548295@qq.com> Date: Thu, 19 Sep 2024 22:51:25 +0800 Subject: [PATCH 4/4] fix: replace list type with typing.List --- lineless_table_rec/utils_table_recover.py | 6 +++--- wired_table_rec/utils_table_recover.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lineless_table_rec/utils_table_recover.py b/lineless_table_rec/utils_table_recover.py index c99b02d..30921cb 100644 --- a/lineless_table_rec/utils_table_recover.py +++ b/lineless_table_rec/utils_table_recover.py @@ -96,7 +96,7 @@ def filter_duplicated_box(table_boxes: List[List[float]]) -> Set[int]: def calculate_iou( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list] + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List] ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -129,7 +129,7 @@ def calculate_iou( def caculate_single_axis_iou( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x" + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], axis="x" ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -153,7 +153,7 @@ def caculate_single_axis_iou( def is_box_contained( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], threshold=0.2 + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2 ) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] diff --git a/wired_table_rec/utils_table_recover.py b/wired_table_rec/utils_table_recover.py index 447e9cc..a12e6e9 100644 --- a/wired_table_rec/utils_table_recover.py +++ b/wired_table_rec/utils_table_recover.py @@ -37,7 +37,7 @@ def sorted_boxes(dt_boxes: np.ndarray) -> np.ndarray: def calculate_iou( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list] + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List] ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -70,7 +70,7 @@ def calculate_iou( def caculate_single_axis_iou( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], axis="x" + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], axis="x" ) -> float: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -94,7 +94,7 @@ def caculate_single_axis_iou( def is_box_contained( - box1: Union[np.ndarray, list], box2: Union[np.ndarray, list], threshold=0.2 + box1: Union[np.ndarray, List], box2: Union[np.ndarray, List], threshold=0.2 ) -> Union[int, None]: """ :param box1: Iterable [xmin,ymin,xmax,ymax] @@ -138,8 +138,8 @@ def is_box_contained( def is_single_axis_contained( - box1: Union[np.ndarray, list], - box2: Union[np.ndarray, list], + box1: Union[np.ndarray, List], + box2: Union[np.ndarray, List], axis="x", threhold: float = 0.2, ) -> Union[int, None]: @@ -558,7 +558,7 @@ def is_inclusive_each_other(box1: np.ndarray, box2: np.ndarray): def plot_html_table( - logi_points: Union[Union[np.ndarray, list]], cell_box_map: Dict[int, List[str]] + logi_points: Union[Union[np.ndarray, List]], cell_box_map: Dict[int, List[str]] ) -> str: # 初始化最大行数和列数 max_row = 0