Skip to content

Commit

Permalink
Improve table recognition and equation insertion
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 3, 2024
1 parent 4786f17 commit df6f8fc
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 34 deletions.
4 changes: 4 additions & 0 deletions marker/cleaners/fontstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def find_bold_italic(pages: List[Page], bold_min_weight=550):
span.italic = True

font_weights.append(span.font_weight)

if len(font_weights) == 0:
return

font_weights = np.array(font_weights)
bold_thresh = np.percentile(font_weights, 90)
bold_thresh_lower = np.percentile(font_weights, 75)
Expand Down
106 changes: 87 additions & 19 deletions marker/cleaners/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,99 @@ def replace_dots(text):
return text


def replace_newlines(text):
    """Return *text* stripped, with each run of CR/LF characters turned into one space."""
    stripped = text.strip()
    # A run such as "\r\n\n" becomes a single space, not one space per character.
    return re.sub(r'[\r\n]+', ' ', stripped)


def get_table_surya(page, table_box, y_tol=.005) -> List[List[str]]:
table_rows = []
row_y_coord = None
table_row = []
x_position = None
y_position = None
for block_idx, block in enumerate(page.blocks):
for line_idx, line in enumerate(block.lines):
line_bbox = line.bbox
intersect_pct = box_intersection_pct(line_bbox, table_box)
if intersect_pct < .5 or len(line.spans) == 0:
continue
normed_y_start = line_bbox[1] / page.height
if row_y_coord is None or abs(normed_y_start - row_y_coord) < y_tol:
table_row.extend([s.text for s in line.spans])
normed_x_start = line_bbox[0] / page.width
normed_x_end = line_bbox[2] / page.width

cells = [[s.bbox, s.text] for s in line.spans]
if x_position is None or (normed_x_start > x_position and abs(normed_y_start - y_position) < y_tol):
# Same row
table_row.extend(cells)
else:
table_rows.append(table_row)
table_row = [s.text for s in line.spans]
row_y_coord = normed_y_start
# New row
if len(table_row) > 0:
table_rows.append(table_row)
table_row = cells
y_position = normed_y_start
x_position = normed_x_end
if len(table_row) > 0:
table_rows.append(table_row)
table_rows = assign_cells_to_columns(table_rows)
return table_rows


def get_table_pdftext(page: Page, table_box) -> List[List[str]]:
def assign_cells_to_columns(rows, round_factor=4, tolerance=4):
    """Align table cells into columns and return each row as a list of strings.

    Args:
        rows: Sequence of rows. Each row is a sequence of cells, and each cell
            is a ``(bbox, text)`` pair in which ``bbox[0]`` is the left edge
            and ``bbox[2]`` is the right edge.
        round_factor: Grid size that edges and centers are snapped to before
            de-duplication, so nearly identical coordinates collapse into a
            single candidate column separator.
        tolerance: Slack subtracted from a cell's left edge when it is
            compared against the separators.

    Returns:
        One list of cell texts per input row. A row is left-padded with empty
        strings up to the column of its first cell. An empty input row yields
        an empty list.
    """
    def snap(value):
        # Snap to the nearest multiple of round_factor. The previous
        # ``value / round_factor * round_factor`` was a no-op.
        return round(value / round_factor) * round_factor

    left_edges = set()
    right_edges = set()
    centers = set()
    for row in rows:
        for cell in row:
            left, right = cell[0][0], cell[0][2]
            left_edges.add(snap(left))
            right_edges.add(snap(right))
            centers.add(snap((left + right) / 2))

    # Use whichever alignment (left, right, or center) yields the fewest
    # distinct positions as the column separators.
    separators = min(
        [sorted(left_edges), sorted(right_edges), sorted(centers)],
        key=len,
    )

    new_rows = []
    for row in rows:
        new_row = {}
        last_col_index = -1
        for cell in row:
            left_edge = cell[0][0]
            column_index = -1
            for i, separator in enumerate(separators):
                if left_edge - tolerance < separator and last_col_index < i:
                    column_index = i
                    break
            if column_index == -1:
                # No usable separator: open a new integer column past both the
                # known separators and the previous cell. Using the raw
                # x-coordinate here could overwrite an earlier cell and made
                # the padding below fail on float keys.
                column_index = max(last_col_index + 1, len(separators))
            new_row[column_index] = cell[1]
            last_col_index = column_index

        if not new_row:
            # Nothing to place; avoid calling min() on an empty mapping.
            new_rows.append([])
            continue

        flat_row = [text for _, text in sorted(new_row.items())]
        new_rows.append([""] * min(new_row) + flat_row)

    return new_rows


def get_table_pdftext(page: Page, table_box, space_tol=.01) -> List[List[str]]:
page_width = page.width
table_rows = []
table_cell = ""
cell_bbox = None
prev_end = None
table_row = []
for block_idx, block in enumerate(page.char_blocks):
for line_idx, line in enumerate(block["lines"]):
line_bbox = line["bbox"]
intersect_pct = box_intersection_pct(line_bbox, table_box)
if intersect_pct < .5:
continue
prev_end = None
table_row = []
table_cell = ""
cell_bbox = None
for span in line["spans"]:
for char in span["chars"]:
x_start, y_start, x_end, y_end = char["bbox"]
Expand All @@ -60,18 +118,28 @@ def get_table_pdftext(page: Page, table_box) -> List[List[str]]:

x_start /= page_width
x_end /= page_width
if prev_end is None or x_start - prev_end < .01:
cell_content = replace_dots(replace_newlines(table_cell))
if prev_end is None or abs(x_start - prev_end) < space_tol: # Check if we are in the same cell
table_cell += char["char"]
else:
table_row.append(replace_dots(table_cell.strip()))
elif x_start > prev_end - space_tol: # Check if we are on the same line
if len(table_cell) > 0:
table_row.append((cell_bbox, cell_content))
table_cell = char["char"]
cell_bbox = char["bbox"]
else: # New line and cell
if len(table_cell) > 0:
table_row.append((cell_bbox, cell_content))
table_cell = char["char"]
cell_bbox = char["bbox"]
if len(table_row) > 0:
table_rows.append(table_row)
table_row = []
prev_end = x_end
if len(table_cell) > 0:
table_row.append(replace_dots(table_cell.strip()))
table_cell = ""
if len(table_row) > 0:
table_rows.append(table_row)
if len(table_cell) > 0:
table_row.append((cell_bbox, replace_dots(replace_newlines(table_cell))))
if len(table_row) > 0:
table_rows.append(table_row)
table_rows = assign_cells_to_columns(table_rows)
return table_rows


Expand Down
46 changes: 36 additions & 10 deletions marker/equations/equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from marker.equations.inference import get_total_texify_tokens, get_latex_batched
from marker.schema.bbox import rescale_bbox
from marker.schema.page import Page
from marker.schema.block import Line, Span, Block, bbox_from_lines
from marker.schema.block import Line, Span, Block, bbox_from_lines, split_block_lines
from marker.settings import settings


Expand All @@ -28,11 +28,7 @@ def find_equation_blocks(page, processor):
equation_lines[region_idx].append(line)

if region_idx not in insert_points:
# Insert before the block if line is at the beginning of the block, otherwise after the block
if line_idx <= len(block.lines) // 2:
insert_points[region_idx] = block_idx
else:
insert_points[region_idx] = block_idx + 1
insert_points[region_idx] = (block_idx, line_idx)

block_lines_to_remove = defaultdict(set)
for region_idx, equation_region in enumerate(equation_regions):
Expand All @@ -44,8 +40,13 @@ def find_equation_blocks(page, processor):
equation_bbox = bbox_from_lines(equation_block)

total_tokens = get_total_texify_tokens(block_text, processor)
selected_blocks = (equation_insert, total_tokens, block_text, equation_bbox)
equation_insert_line_idx = equation_insert[1]
equation_insert_line_idx -= len(
[x for x in lines_to_remove[region_idx] if x[0] == equation_insert[0] and x[1] < equation_insert[1]])

selected_blocks = [equation_insert[0], equation_insert_line_idx, total_tokens, block_text, equation_bbox]
if total_tokens < settings.TEXIFY_MODEL_MAX:
# Account for the lines we're about to remove
for item in lines_to_remove[region_idx]:
block_lines_to_remove[item[0]].add(item[1])
equation_blocks.append(selected_blocks)
Expand All @@ -58,12 +59,19 @@ def find_equation_blocks(page, processor):
return equation_blocks


def increment_insert_points(page_equation_blocks, insert_block_idx, insert_count):
    """Shift the block index of pending equation entries, in place.

    Every entry in ``page_equation_blocks`` whose first element (its block
    index) is at or after ``insert_block_idx`` has that index increased by
    ``insert_count``. Entries must be mutable five-element sequences.
    """
    for entry in page_equation_blocks:
        block_idx, _line_idx, _token_count, _block_text, _equation_bbox = entry
        if block_idx < insert_block_idx:
            continue
        entry[0] += insert_count


def insert_latex_block(page_blocks: Page, page_equation_blocks, predictions, pnum, processor):
converted_spans = []
idx = 0
success_count = 0
fail_count = 0
for block_number, (insert_point, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks):
total_inserted = 0
for block_number, (insert_block_idx, insert_line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks):
latex_text = predictions[block_number]
conditions = [
get_total_texify_tokens(latex_text, processor) < settings.TEXIFY_MODEL_MAX, # Make sure we didn't get to the overall token max, indicates run-on
Expand Down Expand Up @@ -97,7 +105,25 @@ def insert_latex_block(page_blocks: Page, page_equation_blocks, predictions, pnu
new_block.lines[0].spans[0].text = latex_text
converted_spans.append(deepcopy(new_block.lines[0].spans[0]))

page_blocks.blocks.insert(insert_point, new_block)
# Add in the new LaTeX block
if insert_line_idx == 0:
page_blocks.blocks.insert(insert_block_idx, new_block)
increment_insert_points(page_equation_blocks, insert_block_idx, 1)
elif insert_line_idx >= len(page_blocks.blocks[insert_block_idx].lines):
page_blocks.blocks.insert(insert_block_idx + 1, new_block)
increment_insert_points(page_equation_blocks, insert_block_idx + 1, 1)
else:
new_blocks = []
for block_idx, block in enumerate(page_blocks.blocks):
if block_idx == insert_block_idx:
split_block = split_block_lines(block, insert_line_idx)
new_blocks.append(split_block[0])
new_blocks.append(new_block)
new_blocks.append(split_block[1])
increment_insert_points(page_equation_blocks, insert_block_idx, 2)
else:
new_blocks.append(block)
page_blocks.blocks = new_blocks

return success_count, fail_count, converted_spans

Expand All @@ -117,7 +143,7 @@ def replace_equations(doc, pages: List[Page], texify_model, batch_size=settings.
token_counts = []
for page_idx, page_equation_blocks in enumerate(equation_blocks):
page_obj = doc[page_idx]
for equation_idx, (insert_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks):
for equation_idx, (insert_block_idx, insert_line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks):
png_image = get_equation_image(page_obj, pages[page_idx], equation_bbox)

images.append(png_image)
Expand Down
2 changes: 1 addition & 1 deletion marker/ocr/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def detected_line_coverage(page: Page, intersect_thresh=.5, detection_thresh=.6)
total_intersection = 0
for block in page.blocks:
for line in block.lines:
intersection_pct = box_intersection_pct(detected_bbox, line.bbox)
intersection_pct = box_intersection_pct(line.bbox, detected_bbox)
total_intersection += intersection_pct
if total_intersection > intersect_thresh:
found_lines += 1
Expand Down
4 changes: 2 additions & 2 deletions marker/ocr/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def _tesseract_recognition(in_pdf, langs: List[str]) -> Optional[Page]:
out_pdf,
language=langs[0],
output_type="pdf",
redo_ocr=None if settings.OCR_ALL_PAGES else True,
force_ocr=True if settings.OCR_ALL_PAGES else None,
redo_ocr=None,
force_ocr=True,
progress_bar=False,
optimize=False,
fast_web_view=1e6,
Expand Down
9 changes: 7 additions & 2 deletions marker/pdf/extract_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def pdftext_format_to_blocks(page, pnum: int) -> Page:
block_text = s["text"].rstrip("\n")
block_text = block_text.replace("-\n", "") # Remove hyphenated line breaks
span_obj = Span(
text=block_text.rstrip("\n"), # Remove end of line newlines, not spaces
text=block_text, # Remove end of line newlines, not spaces
bbox=s["bbox"],
span_id=f"{pnum}_{span_id}",
font=f"{s['font']['name']}_{font_flags_decomposer(s['font']['flags'])}", # Add font flags to end of font
Expand All @@ -49,10 +49,15 @@ def pdftext_format_to_blocks(page, pnum: int) -> Page:
# Only select blocks with lines
if len(block_lines) > 0:
page_blocks.append(block_obj)

page_bbox = page["bbox"]
page_width = abs(page_bbox[2] - page_bbox[0])
page_height = abs(page_bbox[3] - page_bbox[1])
page_bbox = [0, 0, page_width, page_height]
out_page = Page(
blocks=page_blocks,
pnum=page["page"],
bbox=page["bbox"],
bbox=page_bbox,
rotation=page["rotation"],
char_blocks=page["blocks"]
)
Expand Down
Empty file added marker/postprocessors/images.py
Empty file.
12 changes: 12 additions & 0 deletions marker/schema/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,15 @@ def bbox_from_lines(lines: List[Line]):
max_x = max([line.bbox[2] for line in lines])
max_y = max([line.bbox[3] for line in lines])
return [min_x, min_y, max_x, max_y]


def split_block_lines(block: Block, split_line_idx: int):
    """Split *block* into two blocks at ``split_line_idx``.

    Returns ``[block]`` unchanged when the index is 0 or is at/after the end
    of ``block.lines``. Otherwise returns two new blocks: the first holds the
    lines before the index and the second the lines from the index onward,
    each with a bbox computed from its own lines and the original ``pnum``.
    """
    if split_line_idx == 0 or split_line_idx >= len(block.lines):
        # Nothing to split off on one side; callers get the block back as-is.
        return [block]

    head_lines = block.lines[:split_line_idx]
    tail_lines = block.lines[split_line_idx:]
    head = Block(lines=head_lines, bbox=bbox_from_lines(head_lines), pnum=block.pnum)
    tail = Block(lines=tail_lines, bbox=bbox_from_lines(tail_lines), pnum=block.pnum)
    return [head, tail]

0 comments on commit df6f8fc

Please sign in to comment.