From df6f8fcf8f7b7d938ad0300155ef5752bf4e1612 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Fri, 3 May 2024 14:14:05 -0700 Subject: [PATCH] Improve table recognition and equation insertion --- marker/cleaners/fontstyle.py | 4 ++ marker/cleaners/table.py | 106 ++++++++++++++++++++++++++------ marker/equations/equations.py | 46 +++++++++++--- marker/ocr/heuristics.py | 2 +- marker/ocr/recognition.py | 4 +- marker/pdf/extract_text.py | 9 ++- marker/postprocessors/images.py | 0 marker/schema/block.py | 12 ++++ 8 files changed, 149 insertions(+), 34 deletions(-) create mode 100644 marker/postprocessors/images.py diff --git a/marker/cleaners/fontstyle.py b/marker/cleaners/fontstyle.py index a92d8bd6..2f4a6185 100644 --- a/marker/cleaners/fontstyle.py +++ b/marker/cleaners/fontstyle.py @@ -20,6 +20,10 @@ def find_bold_italic(pages: List[Page], bold_min_weight=550): span.italic = True font_weights.append(span.font_weight) + + if len(font_weights) == 0: + return + font_weights = np.array(font_weights) bold_thresh = np.percentile(font_weights, 90) bold_thresh_lower = np.percentile(font_weights, 75) diff --git a/marker/cleaners/table.py b/marker/cleaners/table.py index bfb0e200..2a5f36ac 100644 --- a/marker/cleaners/table.py +++ b/marker/cleaners/table.py @@ -15,10 +15,17 @@ def replace_dots(text): return text +def replace_newlines(text): + # Replace all newlines + newline_pattern = re.compile(r'[\r\n]+') + return newline_pattern.sub(' ', text.strip()) + + def get_table_surya(page, table_box, y_tol=.005) -> List[List[str]]: table_rows = [] - row_y_coord = None table_row = [] + x_position = None + y_position = None for block_idx, block in enumerate(page.blocks): for line_idx, line in enumerate(block.lines): line_bbox = line.bbox @@ -26,30 +33,81 @@ def get_table_surya(page, table_box, y_tol=.005) -> List[List[str]]: if intersect_pct < .5 or len(line.spans) == 0: continue normed_y_start = line_bbox[1] / page.height - if row_y_coord is None or abs(normed_y_start - row_y_coord) < y_tol: - table_row.extend([s.text for s in line.spans]) + normed_x_start = line_bbox[0] / page.width + normed_x_end = line_bbox[2] / page.width + + cells = [[s.bbox, s.text] for s in line.spans] + if x_position is None or (normed_x_start > x_position and abs(normed_y_start - y_position) < y_tol): + # Same row + table_row.extend(cells) else: - table_rows.append(table_row) - table_row = [s.text for s in line.spans] - row_y_coord = normed_y_start + # New row + if len(table_row) > 0: + table_rows.append(table_row) + table_row = cells + y_position = normed_y_start + x_position = normed_x_end if len(table_row) > 0: table_rows.append(table_row) + table_rows = assign_cells_to_columns(table_rows) return table_rows -def get_table_pdftext(page: Page, table_box) -> List[List[str]]: +def assign_cells_to_columns(rows, round_factor=4, tolerance=4): + left_edges = [] + right_edges = [] + centers = [] + + for row in rows: + for cell in row: + left_edges.append(cell[0][0] / round_factor * round_factor) + right_edges.append(cell[0][2] / round_factor * round_factor) + centers.append((cell[0][0] + cell[0][2]) / 2 * round_factor / round_factor) + + unique_left = sorted(list(set(left_edges))) + unique_right = sorted(list(set(right_edges))) + unique_center = sorted(list(set(centers))) + + # Find list with minimum length + separators = min([unique_left, unique_right, unique_center], key=len) + + new_rows = [] + for row in rows: + new_row = {} + last_col_index = -1 + for cell in row: + left_edge = cell[0][0] + column_index = -1 + for i, separator in enumerate(separators): + if left_edge - tolerance < separator and last_col_index < i: + column_index = i + break + if column_index == -1: + column_index = cell[0][0] # Assign a new column + new_row[column_index] = cell[1] + last_col_index = column_index + + flat_row = [cell[1] for cell in sorted(new_row.items())] + min_column_index = min(new_row.keys()) + flat_row = [""] * min_column_index + flat_row + new_rows.append(flat_row) + + return new_rows + + +def get_table_pdftext(page: Page, table_box, space_tol=.01) -> List[List[str]]: page_width = page.width table_rows = [] + table_cell = "" + cell_bbox = None + prev_end = None + table_row = [] for block_idx, block in enumerate(page.char_blocks): for line_idx, line in enumerate(block["lines"]): line_bbox = line["bbox"] intersect_pct = box_intersection_pct(line_bbox, table_box) if intersect_pct < .5: continue - prev_end = None - table_row = [] - table_cell = "" - cell_bbox = None for span in line["spans"]: for char in span["chars"]: x_start, y_start, x_end, y_end = char["bbox"] @@ -60,18 +118,28 @@ def get_table_pdftext(page: Page, table_box) -> List[List[str]]: x_start /= page_width x_end /= page_width - if prev_end is None or x_start - prev_end < .01: + cell_content = replace_dots(replace_newlines(table_cell)) + if prev_end is None or abs(x_start - prev_end) < space_tol: # Check if we are in the same cell table_cell += char["char"] - else: - table_row.append(replace_dots(table_cell.strip())) + elif x_start > prev_end - space_tol: # Check if we are on the same line + if len(table_cell) > 0: + table_row.append((cell_bbox, cell_content)) table_cell = char["char"] cell_bbox = char["bbox"] + else: # New line and cell + if len(table_cell) > 0: + table_row.append((cell_bbox, cell_content)) + table_cell = char["char"] + cell_bbox = char["bbox"] + if len(table_row) > 0: + table_rows.append(table_row) + table_row = [] prev_end = x_end - if len(table_cell) > 0: - table_row.append(replace_dots(table_cell.strip())) - table_cell = "" - if len(table_row) > 0: - table_rows.append(table_row) + if len(table_cell) > 0: + table_row.append((cell_bbox, replace_dots(replace_newlines(table_cell)))) + if len(table_row) > 0: + table_rows.append(table_row) + table_rows = assign_cells_to_columns(table_rows) return table_rows diff --git a/marker/equations/equations.py b/marker/equations/equations.py index d0246f60..da23b136 100644 --- a/marker/equations/equations.py +++ b/marker/equations/equations.py @@ -7,7 +7,7 @@ from marker.equations.inference import get_total_texify_tokens, get_latex_batched from marker.schema.bbox import rescale_bbox from marker.schema.page import Page -from marker.schema.block import Line, Span, Block, bbox_from_lines +from marker.schema.block import Line, Span, Block, bbox_from_lines, split_block_lines from marker.settings import settings @@ -28,11 +28,7 @@ def find_equation_blocks(page, processor): equation_lines[region_idx].append(line) if region_idx not in insert_points: - # Insert before the block if line is at the beginning of the block, otherwise after the block - if line_idx <= len(block.lines) // 2: - insert_points[region_idx] = block_idx - else: - insert_points[region_idx] = block_idx + 1 + insert_points[region_idx] = (block_idx, line_idx) block_lines_to_remove = defaultdict(set) for region_idx, equation_region in enumerate(equation_regions): @@ -44,8 +40,13 @@ def find_equation_blocks(page, processor): equation_bbox = bbox_from_lines(equation_block) total_tokens = get_total_texify_tokens(block_text, processor) - selected_blocks = (equation_insert, total_tokens, block_text, equation_bbox) + equation_insert_line_idx = equation_insert[1] + equation_insert_line_idx -= len( + [x for x in lines_to_remove[region_idx] if x[0] == equation_insert[0] and x[1] < equation_insert[1]]) + + selected_blocks = [equation_insert[0], equation_insert_line_idx, total_tokens, block_text, equation_bbox] if total_tokens < settings.TEXIFY_MODEL_MAX: + # Account for the lines we're about to remove for item in lines_to_remove[region_idx]: block_lines_to_remove[item[0]].add(item[1]) equation_blocks.append(selected_blocks) @@ -58,12 +59,19 @@ def find_equation_blocks(page, processor): return equation_blocks +def increment_insert_points(page_equation_blocks, insert_block_idx, insert_count): + for idx, (block_idx, line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks): + if block_idx >= insert_block_idx: + page_equation_blocks[idx][0] += insert_count + + def insert_latex_block(page_blocks: Page, page_equation_blocks, predictions, pnum, processor): converted_spans = [] idx = 0 success_count = 0 fail_count = 0 - for block_number, (insert_point, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks): + total_inserted = 0 + for block_number, (insert_block_idx, insert_line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks): latex_text = predictions[block_number] conditions = [ get_total_texify_tokens(latex_text, processor) < settings.TEXIFY_MODEL_MAX, # Make sure we didn't get to the overall token max, indicates run-on @@ -97,7 +105,25 @@ def insert_latex_block(page_blocks: Page, page_equation_blocks, predictions, pnu new_block.lines[0].spans[0].text = latex_text converted_spans.append(deepcopy(new_block.lines[0].spans[0])) - page_blocks.blocks.insert(insert_point, new_block) + # Add in the new LaTeX block + if insert_line_idx == 0: + page_blocks.blocks.insert(insert_block_idx, new_block) + increment_insert_points(page_equation_blocks, insert_block_idx, 1) + elif insert_line_idx >= len(page_blocks.blocks[insert_block_idx].lines): + page_blocks.blocks.insert(insert_block_idx + 1, new_block) + increment_insert_points(page_equation_blocks, insert_block_idx + 1, 1) + else: + new_blocks = [] + for block_idx, block in enumerate(page_blocks.blocks): + if block_idx == insert_block_idx: + split_block = split_block_lines(block, insert_line_idx) + new_blocks.append(split_block[0]) + new_blocks.append(new_block) + new_blocks.append(split_block[1]) + increment_insert_points(page_equation_blocks, insert_block_idx, 2) + else: + new_blocks.append(block) + page_blocks.blocks = new_blocks return success_count, fail_count, converted_spans @@ -117,7 +143,7 @@ def replace_equations(doc, pages: List[Page], texify_model, batch_size=settings. token_counts = [] for page_idx, page_equation_blocks in enumerate(equation_blocks): page_obj = doc[page_idx] - for equation_idx, (insert_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks): + for equation_idx, (insert_block_idx, insert_line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks): png_image = get_equation_image(page_obj, pages[page_idx], equation_bbox) images.append(png_image) diff --git a/marker/ocr/heuristics.py b/marker/ocr/heuristics.py index ffe6e422..b83d5566 100644 --- a/marker/ocr/heuristics.py +++ b/marker/ocr/heuristics.py @@ -63,7 +63,7 @@ def detected_line_coverage(page: Page, intersect_thresh=.5, detection_thresh=.6) total_intersection = 0 for block in page.blocks: for line in block.lines: - intersection_pct = box_intersection_pct(detected_bbox, line.bbox) + intersection_pct = box_intersection_pct(line.bbox, detected_bbox) total_intersection += intersection_pct if total_intersection > intersect_thresh: found_lines += 1 diff --git a/marker/ocr/recognition.py b/marker/ocr/recognition.py index d62624b8..535f6507 100644 --- a/marker/ocr/recognition.py +++ b/marker/ocr/recognition.py @@ -120,8 +120,8 @@ def _tesseract_recognition(in_pdf, langs: List[str]) -> Optional[Page]: out_pdf, language=langs[0], output_type="pdf", - redo_ocr=None if settings.OCR_ALL_PAGES else True, - force_ocr=True if settings.OCR_ALL_PAGES else None, + redo_ocr=None, + force_ocr=True, progress_bar=False, optimize=False, fast_web_view=1e6, diff --git a/marker/pdf/extract_text.py b/marker/pdf/extract_text.py index 1512643c..56871d6a 100644 --- a/marker/pdf/extract_text.py +++ b/marker/pdf/extract_text.py @@ -25,7 +25,7 @@ def pdftext_format_to_blocks(page, pnum: int) -> Page: block_text = s["text"].rstrip("\n") block_text = block_text.replace("-\n", "") # Remove hyphenated line breaks span_obj = Span( - text=block_text.rstrip("\n"), # Remove end of line newlines, not spaces + text=block_text, # Remove end of line newlines, not spaces bbox=s["bbox"], span_id=f"{pnum}_{span_id}", font=f"{s['font']['name']}_{font_flags_decomposer(s['font']['flags'])}", # Add font flags to end of font @@ -49,10 +49,15 @@ def pdftext_format_to_blocks(page, pnum: int) -> Page: # Only select blocks with lines if len(block_lines) > 0: page_blocks.append(block_obj) + + page_bbox = page["bbox"] + page_width = abs(page_bbox[2] - page_bbox[0]) + page_height = abs(page_bbox[3] - page_bbox[1]) + page_bbox = [0, 0, page_width, page_height] out_page = Page( blocks=page_blocks, pnum=page["page"], - bbox=page["bbox"], + bbox=page_bbox, rotation=page["rotation"], char_blocks=page["blocks"] ) diff --git a/marker/postprocessors/images.py b/marker/postprocessors/images.py new file mode 100644 index 00000000..e69de29b diff --git a/marker/schema/block.py b/marker/schema/block.py index df4f90e8..1220b698 100644 --- a/marker/schema/block.py +++ b/marker/schema/block.py @@ -86,3 +86,15 @@ def bbox_from_lines(lines: List[Line]): max_x = max([line.bbox[2] for line in lines]) max_y = max([line.bbox[3] for line in lines]) return [min_x, min_y, max_x, max_y] + + +def split_block_lines(block: Block, split_line_idx: int): + new_blocks = [] + if split_line_idx >= len(block.lines): + return [block] + elif split_line_idx == 0: + return [block] + else: + new_blocks.append(Block(lines=block.lines[:split_line_idx], bbox=bbox_from_lines(block.lines[:split_line_idx]), pnum=block.pnum)) + new_blocks.append(Block(lines=block.lines[split_line_idx:], bbox=bbox_from_lines(block.lines[split_line_idx:]), pnum=block.pnum)) + return new_blocks