Skip to content

Commit

Permalink
Add in surya OCR
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 1, 2024
1 parent 51990c8 commit 30da488
Show file tree
Hide file tree
Showing 28 changed files with 655 additions and 517 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,19 @@ python convert.py /path/to/input/folder /path/to/output/folder --workers 10 --ma

- `--workers` is the number of pdfs to convert at once. This is set to 1 by default, but you can increase it to increase throughput, at the cost of more CPU/GPU usage. Parallelism will not increase beyond `INFERENCE_RAM / VRAM_PER_TASK` if you're using GPU.
- `--max` is the maximum number of pdfs to convert. Omit this to convert all pdfs in the folder.
- `--metadata_file` is an optional path to a json file with metadata about the pdfs. If you provide it, it will be used to set the language for each pdf. If not, `DEFAULT_LANG` will be used. The format is:
- `--min_length` is the minimum number of characters that need to be extracted from a pdf before it will be considered for processing. If you're processing a lot of pdfs, I recommend setting this to avoid OCRing pdfs that are mostly images. (slows everything down)
- `--metadata_file` is an optional path to a json file with metadata about the pdfs. If you provide it, it will be used to set the language for each pdf. If not, `DEFAULT_LANG` will be used. The format is:

```
{
"pdf1.pdf": {"language": "English"},
"pdf2.pdf": {"language": "Spanish"},
"pdf1.pdf": {"languages": ["English"]},
"pdf2.pdf": {"languages": ["Spanish", "Russian"]},
...
}
```

You can use language names or codes. See [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py) for a full list.

## Convert multiple files on multiple GPUs

Run `chunk_convert.sh`, like this:
Expand Down
2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from marker.logger import configure_logging
from marker.models import load_all_models
from marker.benchmark.scoring import score_text
from marker.extract_text import naive_get_text
from marker.pdf.extract_text import naive_get_text
import json
import os
import subprocess
Expand Down
8 changes: 7 additions & 1 deletion convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from tqdm import tqdm
import math

from marker.convert import convert_single_pdf, get_length_of_text
from marker.convert import convert_single_pdf
from marker.pdf.filetype import find_filetype
from marker.pdf.extract_text import get_length_of_text
from marker.models import load_all_models
from marker.settings import settings
from marker.logger import configure_logging
Expand All @@ -28,6 +30,10 @@ def process_single_pdf(fname: str, out_folder: str, model_refs, metadata: Option
# This can indicate that they were scanned, and not OCRed properly
# Usually these files are not recent/high-quality
if min_length:
filetype = find_filetype(fname)
if filetype == "other":
return 0

length = get_length_of_text(fname)
if length < min_length:
return
Expand Down
3 changes: 2 additions & 1 deletion marker/cleaners/code.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from marker.schema import Span, Line, Page
from marker.schema.schema import Span, Line
from marker.schema.page import Page
import re
from typing import List

Expand Down
25 changes: 7 additions & 18 deletions marker/cleaners/equations.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,20 @@
import io
from copy import deepcopy
from functools import partial
from typing import List

import torch
from texify.inference import batch_inference
from texify.model.model import load_model
from texify.model.processor import load_processor
import re

from PIL import Image, ImageDraw

from marker.bbox import should_merge_blocks, merge_boxes
from marker.schema.bbox import should_merge_blocks, merge_boxes
from marker.debug.data import dump_equation_debug_data
from marker.pdf.images import render_image
from marker.settings import settings
from marker.schema import Page, Span, Line, Block, BlockType
from marker.schema.schema import Span, Line, Block, BlockType
from marker.schema.page import Page
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

processor = load_processor()


def load_texify_model():
texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
return texify_model


def mask_bbox(png_image, bbox, selected_bboxes):
mask = Image.new('L', png_image.size, 0) # 'L' mode for grayscale
Expand Down Expand Up @@ -72,10 +61,10 @@ def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
max_length = min(max_length, settings.TEXIFY_MODEL_MAX)
max_length += settings.TEXIFY_TOKEN_BUFFER

model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length)
model_output = batch_inference(images[min_idx:max_idx], texify_model, texify_model.processor, max_tokens=max_length)

for j, output in enumerate(model_output):
token_count = get_total_texify_tokens(output)
token_count = get_total_texify_tokens(output, texify_model.processor)
if token_count >= max_length - 1:
output = ""

Expand All @@ -84,7 +73,7 @@ def get_latex_batched(images, reformat_region_lens, texify_model, batch_size):
return predictions


def get_total_texify_tokens(text):
def get_total_texify_tokens(text, processor):
tokenizer = processor.tokenizer
tokens = tokenizer(text)
return len(tokens["input_ids"])
Expand Down
3 changes: 2 additions & 1 deletion marker/cleaners/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from sklearn.cluster import DBSCAN
import numpy as np

from marker.schema import Page, FullyMergedBlock
from marker.schema.schema import FullyMergedBlock
from marker.schema.page import Page
from typing import List, Tuple


Expand Down
5 changes: 3 additions & 2 deletions marker/cleaners/table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from marker.bbox import merge_boxes
from marker.schema import Line, Span, Block, Page
from marker.schema.bbox import merge_boxes
from marker.schema.schema import Line, Span, Block
from marker.schema.page import Page
from copy import deepcopy
from tabulate import tabulate
from typing import List
Expand Down
107 changes: 42 additions & 65 deletions marker/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,125 +2,102 @@

from marker.cleaners.table import merge_table_blocks, create_new_tables
from marker.debug.data import dump_bbox_debug_data
from marker.extract_text import get_text_blocks
from marker.ocr.lang import replace_langs_with_codes, validate_langs
from marker.ocr.detection import surya_detection
from marker.ocr.recognition import run_ocr
from marker.pdf.extract_text import get_text_blocks
from marker.cleaners.headers import filter_header_footer, filter_common_titles
from marker.cleaners.equations import replace_equations
from marker.ordering import order_blocks
from marker.pdf.filetype import find_filetype
from marker.postprocessors.editor import edit_full_text
from marker.segmentation import detect_document_block_types
from marker.cleaners.code import identify_code_blocks, indent_blocks
from marker.cleaners.bullets import replace_bullets
from marker.markdown import merge_spans, merge_lines, get_full_text
from marker.schema import Page, BlockType
from marker.schema.schema import BlockType
from marker.schema.page import Page
from typing import List, Dict, Tuple, Optional
import re
import magic
from marker.settings import settings


def find_filetype(fpath):
mimetype = magic.from_file(fpath).lower()

# Get extensions from mimetype
# The mimetype is not always consistent, so use in to check the most common formats
if "pdf" in mimetype:
return "pdf"
#elif "epub" in mimetype:
# return "epub"
#elif "mobi" in mimetype:
# return "mobi"
elif mimetype in settings.SUPPORTED_FILETYPES:
return settings.SUPPORTED_FILETYPES[mimetype]
else:
print(f"Found nonstandard filetype {mimetype}")
return "other"


def annotate_spans(blocks: List[Page], block_types: List[BlockType]):
for i, page in enumerate(blocks):
page_block_types = block_types[i]
page.add_block_types(page_block_types)


def get_length_of_text(fname: str) -> int:
filetype = find_filetype(fname)
if filetype == "other":
return 0

doc = pdfium.PdfDocument(fname)
full_text = ""
for page_idx in range(len(doc)):
page = doc.get_page(page_idx)
text_page = page.get_textpage()
full_text += text_page.get_text_bounded()

return len(full_text)


def convert_single_pdf(
fname: str,
model_lst: List,
max_pages=None,
metadata: Optional[Dict]=None,
parallel_factor: int = 1
) -> Tuple[str, Dict]:
lang = settings.DEFAULT_LANG
# Set language needed for OCR
langs = [settings.DEFAULT_LANG]
if metadata:
lang = metadata.get("language", settings.DEFAULT_LANG)
langs = metadata.get("languages", langs)

# Use tesseract language if available
tess_lang = settings.TESSERACT_LANGUAGES.get(lang, "eng")
spell_lang = settings.SPELLCHECK_LANGUAGES.get(lang, None)
if "eng" not in tess_lang:
tess_lang = f"eng+{tess_lang}"

# Output metadata
out_meta = {"language": lang}
langs = replace_langs_with_codes(langs)
validate_langs(langs)

# Find the filetype
filetype = find_filetype(fname)
if filetype == "other":
return "", out_meta

out_meta["filetype"] = filetype
# Setup output metadata
out_meta = {
"languages": langs,
"filetype": filetype,
}

if filetype == "other": # We can't process this file
return "", out_meta

# Get initial text blocks from the pdf
doc = pdfium.PdfDocument(fname)
blocks, toc, ocr_stats = get_text_blocks(
pages, toc = get_text_blocks(
doc,
tess_lang,
spell_lang,
max_pages=max_pages,
parallel=int(parallel_factor * settings.OCR_PARALLEL_WORKERS)
)
out_meta.update({
"toc": toc,
"pages": len(pages),
})

# Unpack models from list
texify_model, layout_model, order_model, edit_model, detection_model, ocr_model = model_lst

# Identify text lines on pages
surya_detection(doc, pages, detection_model)

out_meta["toc"] = toc
out_meta["pages"] = len(blocks)
out_meta["ocr_stats"] = ocr_stats
if len([b for p in blocks for b in p.blocks]) == 0:
# OCR pages as needed
pages, ocr_stats = run_ocr(doc, pages, langs, ocr_model, parallel_factor)

if len([b for p in pages for b in p.blocks]) == 0:
print(f"Could not extract any text blocks for {fname}")
return "", out_meta

# Unpack models from list
texify_model, layoutlm_model, order_model, edit_model = model_lst

block_types = detect_document_block_types(
doc,
blocks,
pages,
layoutlm_model,
batch_size=int(settings.LAYOUT_BATCH_SIZE * parallel_factor)
)

# Find headers and footers
bad_span_ids = filter_header_footer(blocks)
bad_span_ids = filter_header_footer(pages)
out_meta["block_stats"] = {"header_footer": len(bad_span_ids)}

annotate_spans(blocks, block_types)
annotate_spans(pages, block_types)

# Dump debug data if flags are set
dump_bbox_debug_data(doc, blocks)
dump_bbox_debug_data(doc, pages)

blocks = order_blocks(
doc,
blocks,
pages,
order_model,
batch_size=int(settings.ORDERER_BATCH_SIZE * parallel_factor)
)
Expand Down
3 changes: 1 addition & 2 deletions marker/debug/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import base64
import json
import os
import zlib
from typing import List

from marker.pdf.images import render_image
from marker.schema import Page
from marker.schema.page import Page
from marker.settings import settings
from PIL import Image
import io
Expand Down
Empty file added marker/layout/layout.py
Empty file.
Empty file added marker/layout/order.py
Empty file.
3 changes: 2 additions & 1 deletion marker/markdown.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from marker.schema import MergedLine, MergedBlock, FullyMergedBlock, Page
from marker.schema.schema import MergedLine, MergedBlock, FullyMergedBlock
from marker.schema.page import Page
import re
from typing import List

Expand Down
59 changes: 51 additions & 8 deletions marker/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,56 @@
from marker.cleaners.equations import load_texify_model
from marker.ordering import load_ordering_model
from marker.postprocessors.editor import load_editing_model
from marker.segmentation import load_layout_model
from surya.model.detection import segformer
from texify.model.model import load_model as load_texify_model
from texify.model.processor import load_processor as load_texify_processor
from marker.settings import settings
from surya.model.recognition.model import load_model as load_recognition_model
from surya.model.recognition.processor import load_processor as load_recognition_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.model.ordering.processor import load_processor as load_order_processor


def load_all_models():
def setup_recognition_model(langs):
rec_model = load_recognition_model(langs=langs)
rec_processor = load_recognition_processor()
rec_model.processor = rec_processor
return rec_model


def setup_detection_model():
model = segformer.load_model()
processor = segformer.load_processor()
model.processor = processor
return model


def setup_texify_model():
texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE)
texify_processor = load_texify_processor()
texify_model.processor = texify_processor
return texify_model


def setup_layout_model():
model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
processor = segformer.load_processor()
model.processor = processor
return model


def setup_order_model():
model = load_order_model()
processor = load_order_processor()
model.processor = processor
return model


def load_all_models(langs=None):
# langs is optional list of languages to prune from recognition MoE model
detection = setup_detection_model()
layout = setup_layout_model()
order = setup_order_model()
edit = load_editing_model()
order = load_ordering_model()
layout = load_layout_model()
texify = load_texify_model()
model_lst = [texify, layout, order, edit]
ocr = setup_recognition_model(langs)
texify = setup_texify_model()
model_lst = [texify, layout, order, edit, detection, ocr]
return model_lst
Loading

0 comments on commit 30da488

Please sign in to comment.