diff --git a/README.md b/README.md index 67456b45..3950a82a 100644 --- a/README.md +++ b/README.md @@ -123,17 +123,19 @@ python convert.py /path/to/input/folder /path/to/output/folder --workers 10 --ma - `--workers` is the number of pdfs to convert at once. This is set to 1 by default, but you can increase it to increase throughput, at the cost of more CPU/GPU usage. Parallelism will not increase beyond `INFERENCE_RAM / VRAM_PER_TASK` if you're using GPU. - `--max` is the maximum number of pdfs to convert. Omit this to convert all pdfs in the folder. -- `--metadata_file` is an optional path to a json file with metadata about the pdfs. If you provide it, it will be used to set the language for each pdf. If not, `DEFAULT_LANG` will be used. The format is: - `--min_length` is the minimum number of characters that need to be extracted from a pdf before it will be considered for processing. If you're processing a lot of pdfs, I recommend setting this to avoid OCRing pdfs that are mostly images. (slows everything down) +- `--metadata_file` is an optional path to a json file with metadata about the pdfs. If you provide it, it will be used to set the language for each pdf. If not, `DEFAULT_LANG` will be used. The format is: ``` { - "pdf1.pdf": {"language": "English"}, - "pdf2.pdf": {"language": "Spanish"}, + "pdf1.pdf": {"languages": ["English"]}, + "pdf2.pdf": {"languages": ["Spanish", "Russian"]}, ... } ``` +You can use language names or codes. See [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py) for a full list. 
+ ## Convert multiple files on multiple GPUs Run `chunk_convert.sh`, like this: diff --git a/benchmark.py b/benchmark.py index 59c5de59..bd515c42 100644 --- a/benchmark.py +++ b/benchmark.py @@ -9,7 +9,7 @@ from marker.logger import configure_logging from marker.models import load_all_models from marker.benchmark.scoring import score_text -from marker.extract_text import naive_get_text +from marker.pdf.extract_text import naive_get_text import json import os import subprocess diff --git a/convert.py b/convert.py index f9837440..96b567fd 100755 --- a/convert.py +++ b/convert.py @@ -6,7 +6,9 @@ from tqdm import tqdm import math -from marker.convert import convert_single_pdf, get_length_of_text +from marker.convert import convert_single_pdf +from marker.pdf.filetype import find_filetype +from marker.pdf.extract_text import get_length_of_text from marker.models import load_all_models from marker.settings import settings from marker.logger import configure_logging @@ -28,6 +30,10 @@ def process_single_pdf(fname: str, out_folder: str, model_refs, metadata: Option # This can indicate that they were scanned, and not OCRed properly # Usually these files are not recent/high-quality if min_length: + filetype = find_filetype(fname) + if filetype == "other": + return 0 + length = get_length_of_text(fname) if length < min_length: return diff --git a/marker/cleaners/code.py b/marker/cleaners/code.py index 1146ee78..6ce612dd 100644 --- a/marker/cleaners/code.py +++ b/marker/cleaners/code.py @@ -1,4 +1,5 @@ -from marker.schema import Span, Line, Page +from marker.schema.schema import Span, Line +from marker.schema.page import Page import re from typing import List diff --git a/marker/cleaners/equations.py b/marker/cleaners/equations.py index 2fc7c9a4..be995784 100644 --- a/marker/cleaners/equations.py +++ b/marker/cleaners/equations.py @@ -1,31 +1,20 @@ -import io from copy import deepcopy -from functools import partial from typing import List -import torch from texify.inference 
import batch_inference -from texify.model.model import load_model -from texify.model.processor import load_processor -import re + from PIL import Image, ImageDraw -from marker.bbox import should_merge_blocks, merge_boxes +from marker.schema.bbox import should_merge_blocks, merge_boxes from marker.debug.data import dump_equation_debug_data from marker.pdf.images import render_image from marker.settings import settings -from marker.schema import Page, Span, Line, Block, BlockType +from marker.schema.schema import Span, Line, Block, BlockType +from marker.schema.page import Page import os os.environ["TOKENIZERS_PARALLELISM"] = "false" -processor = load_processor() - - -def load_texify_model(): - texify_model = load_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE) - return texify_model - def mask_bbox(png_image, bbox, selected_bboxes): mask = Image.new('L', png_image.size, 0) # 'L' mode for grayscale @@ -72,10 +61,10 @@ def get_latex_batched(images, reformat_region_lens, texify_model, batch_size): max_length = min(max_length, settings.TEXIFY_MODEL_MAX) max_length += settings.TEXIFY_TOKEN_BUFFER - model_output = batch_inference(images[min_idx:max_idx], texify_model, processor, max_tokens=max_length) + model_output = batch_inference(images[min_idx:max_idx], texify_model, texify_model.processor, max_tokens=max_length) for j, output in enumerate(model_output): - token_count = get_total_texify_tokens(output) + token_count = get_total_texify_tokens(output, texify_model.processor) if token_count >= max_length - 1: output = "" @@ -84,7 +73,7 @@ def get_latex_batched(images, reformat_region_lens, texify_model, batch_size): return predictions -def get_total_texify_tokens(text): +def get_total_texify_tokens(text, processor): tokenizer = processor.tokenizer tokens = tokenizer(text) return len(tokens["input_ids"]) diff --git a/marker/cleaners/headers.py b/marker/cleaners/headers.py index cd30981b..32d257ac 100644 --- 
a/marker/cleaners/headers.py +++ b/marker/cleaners/headers.py @@ -6,7 +6,8 @@ from sklearn.cluster import DBSCAN import numpy as np -from marker.schema import Page, FullyMergedBlock +from marker.schema.schema import FullyMergedBlock +from marker.schema.page import Page from typing import List, Tuple diff --git a/marker/cleaners/table.py b/marker/cleaners/table.py index 1025238d..306b613e 100644 --- a/marker/cleaners/table.py +++ b/marker/cleaners/table.py @@ -1,5 +1,6 @@ -from marker.bbox import merge_boxes -from marker.schema import Line, Span, Block, Page +from marker.schema.bbox import merge_boxes +from marker.schema.schema import Line, Span, Block +from marker.schema.page import Page from copy import deepcopy from tabulate import tabulate from typing import List diff --git a/marker/convert.py b/marker/convert.py index e275d3fd..4c00ec72 100644 --- a/marker/convert.py +++ b/marker/convert.py @@ -2,61 +2,32 @@ from marker.cleaners.table import merge_table_blocks, create_new_tables from marker.debug.data import dump_bbox_debug_data -from marker.extract_text import get_text_blocks +from marker.ocr.lang import replace_langs_with_codes, validate_langs +from marker.ocr.detection import surya_detection +from marker.ocr.recognition import run_ocr +from marker.pdf.extract_text import get_text_blocks from marker.cleaners.headers import filter_header_footer, filter_common_titles from marker.cleaners.equations import replace_equations from marker.ordering import order_blocks +from marker.pdf.filetype import find_filetype from marker.postprocessors.editor import edit_full_text from marker.segmentation import detect_document_block_types from marker.cleaners.code import identify_code_blocks, indent_blocks from marker.cleaners.bullets import replace_bullets from marker.markdown import merge_spans, merge_lines, get_full_text -from marker.schema import Page, BlockType +from marker.schema.schema import BlockType +from marker.schema.page import Page from typing import List, Dict, 
Tuple, Optional import re -import magic from marker.settings import settings -def find_filetype(fpath): - mimetype = magic.from_file(fpath).lower() - - # Get extensions from mimetype - # The mimetype is not always consistent, so use in to check the most common formats - if "pdf" in mimetype: - return "pdf" - #elif "epub" in mimetype: - # return "epub" - #elif "mobi" in mimetype: - # return "mobi" - elif mimetype in settings.SUPPORTED_FILETYPES: - return settings.SUPPORTED_FILETYPES[mimetype] - else: - print(f"Found nonstandard filetype {mimetype}") - return "other" - - def annotate_spans(blocks: List[Page], block_types: List[BlockType]): for i, page in enumerate(blocks): page_block_types = block_types[i] page.add_block_types(page_block_types) -def get_length_of_text(fname: str) -> int: - filetype = find_filetype(fname) - if filetype == "other": - return 0 - - doc = pdfium.PdfDocument(fname) - full_text = "" - for page_idx in range(len(doc)): - page = doc.get_page(page_idx) - text_page = page.get_textpage() - full_text += text_page.get_text_bounded() - - return len(full_text) - - def convert_single_pdf( fname: str, model_lst: List, @@ -64,63 +35,69 @@ def convert_single_pdf( metadata: Optional[Dict]=None, parallel_factor: int = 1 ) -> Tuple[str, Dict]: - lang = settings.DEFAULT_LANG + # Set language needed for OCR + langs = [settings.DEFAULT_LANG] if metadata: - lang = metadata.get("language", settings.DEFAULT_LANG) + langs = metadata.get("languages", langs) - # Use tesseract language if available - tess_lang = settings.TESSERACT_LANGUAGES.get(lang, "eng") - spell_lang = settings.SPELLCHECK_LANGUAGES.get(lang, None) - if "eng" not in tess_lang: - tess_lang = f"eng+{tess_lang}" - - # Output metadata - out_meta = {"language": lang} + langs = replace_langs_with_codes(langs) + validate_langs(langs) + # Find the filetype filetype = find_filetype(fname) - if filetype == "other": - return "", out_meta - out_meta["filetype"] = filetype + # Setup output metadata + out_meta = 
{ + "languages": langs, + "filetype": filetype, + } + if filetype == "other": # We can't process this file + return "", out_meta + + # Get initial text blocks from the pdf doc = pdfium.PdfDocument(fname) - blocks, toc, ocr_stats = get_text_blocks( + pages, toc = get_text_blocks( doc, - tess_lang, - spell_lang, max_pages=max_pages, - parallel=int(parallel_factor * settings.OCR_PARALLEL_WORKERS) ) + out_meta.update({ + "toc": toc, + "pages": len(pages), + }) + + # Unpack models from list + texify_model, layout_model, order_model, edit_model, detection_model, ocr_model = model_lst + + # Identify text lines on pages + surya_detection(doc, pages, detection_model) - out_meta["toc"] = toc - out_meta["pages"] = len(blocks) - out_meta["ocr_stats"] = ocr_stats - if len([b for p in blocks for b in p.blocks]) == 0: + # OCR pages as needed + pages, ocr_stats = run_ocr(doc, pages, langs, ocr_model, parallel_factor) + + if len([b for p in pages for b in p.blocks]) == 0: print(f"Could not extract any text blocks for {fname}") return "", out_meta - # Unpack models from list - texify_model, layoutlm_model, order_model, edit_model = model_lst - block_types = detect_document_block_types( doc, - blocks, - layoutlm_model, + pages, + layout_model, batch_size=int(settings.LAYOUT_BATCH_SIZE * parallel_factor) ) # Find headers and footers - bad_span_ids = filter_header_footer(blocks) + bad_span_ids = filter_header_footer(pages) out_meta["block_stats"] = {"header_footer": len(bad_span_ids)} - annotate_spans(blocks, block_types) + annotate_spans(pages, block_types) # Dump debug data if flags are set - dump_bbox_debug_data(doc, blocks) + dump_bbox_debug_data(doc, pages) blocks = order_blocks( doc, - blocks, + pages, order_model, batch_size=int(settings.ORDERER_BATCH_SIZE * parallel_factor) ) diff --git a/marker/debug/data.py b/marker/debug/data.py index 03566cdc..02d59481 100644 --- a/marker/debug/data.py +++ b/marker/debug/data.py @@ -1,11 +1,10 @@ import base64 import json import os -import zlib from typing 
import List from marker.pdf.images import render_image -from marker.schema import Page +from marker.schema.page import Page from marker.settings import settings from PIL import Image import io diff --git a/marker/layout/layout.py b/marker/layout/layout.py new file mode 100644 index 00000000..e69de29b diff --git a/marker/layout/order.py b/marker/layout/order.py new file mode 100644 index 00000000..e69de29b diff --git a/marker/markdown.py b/marker/markdown.py index 33f475ee..7a374e50 100644 --- a/marker/markdown.py +++ b/marker/markdown.py @@ -1,4 +1,5 @@ -from marker.schema import MergedLine, MergedBlock, FullyMergedBlock, Page +from marker.schema.schema import MergedLine, MergedBlock, FullyMergedBlock +from marker.schema.page import Page import re from typing import List diff --git a/marker/models.py b/marker/models.py index e8a1ee65..9e61fdb7 100644 --- a/marker/models.py +++ b/marker/models.py @@ -1,13 +1,56 @@ -from marker.cleaners.equations import load_texify_model -from marker.ordering import load_ordering_model from marker.postprocessors.editor import load_editing_model -from marker.segmentation import load_layout_model +from surya.model.detection import segformer +from texify.model.model import load_model as load_texify_model +from texify.model.processor import load_processor as load_texify_processor +from marker.settings import settings +from surya.model.recognition.model import load_model as load_recognition_model +from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.ordering.model import load_model as load_order_model +from surya.model.ordering.processor import load_processor as load_order_processor -def load_all_models(): +def setup_recognition_model(langs): + rec_model = load_recognition_model(langs=langs) + rec_processor = load_recognition_processor() + rec_model.processor = rec_processor + return rec_model + + +def setup_detection_model(): + model = segformer.load_model() + processor = 
segformer.load_processor() + model.processor = processor + return model + + +def setup_texify_model(): + texify_model = load_texify_model(checkpoint=settings.TEXIFY_MODEL_NAME, device=settings.TORCH_DEVICE_MODEL, dtype=settings.TEXIFY_DTYPE) + texify_processor = load_texify_processor() + texify_model.processor = texify_processor + return texify_model + + +def setup_layout_model(): + model = segformer.load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + processor = segformer.load_processor() + model.processor = processor + return model + + +def setup_order_model(): + model = load_order_model() + processor = load_order_processor() + model.processor = processor + return model + + +def load_all_models(langs=None): + # langs is optional list of languages to prune from recognition MoE model + detection = setup_detection_model() + layout = setup_layout_model() + order = setup_order_model() edit = load_editing_model() - order = load_ordering_model() - layout = load_layout_model() - texify = load_texify_model() - model_lst = [texify, layout, order, edit] + ocr = setup_recognition_model(langs) + texify = setup_texify_model() + model_lst = [texify, layout, order, edit, detection, ocr] return model_lst diff --git a/marker/ocr/detection.py b/marker/ocr/detection.py new file mode 100644 index 00000000..fc58fe09 --- /dev/null +++ b/marker/ocr/detection.py @@ -0,0 +1,22 @@ +from typing import List + +from pypdfium2 import PdfDocument +from surya.detection import batch_text_detection + +from marker.pdf.images import render_image +from marker.schema.page import Page +from marker.settings import settings + + +def surya_detection(doc: PdfDocument, pages: List[Page], det_model): + processor = det_model.processor + max_len = min(len(pages), len(doc)) + images = [render_image(doc[pnum], dpi=settings.SURYA_DETECTOR_DPI) for pnum in range(max_len)] + + predictions = batch_text_detection(images, det_model, processor) + for (page, pred) in zip(pages, predictions): + page.text_lines = 
pred + + + + diff --git a/marker/ocr/heuristics.py b/marker/ocr/heuristics.py new file mode 100644 index 00000000..0bd7a9f5 --- /dev/null +++ b/marker/ocr/heuristics.py @@ -0,0 +1,71 @@ +import re +from typing import List + +from nltk import wordpunct_tokenize + +from marker.ocr.utils import alphanum_ratio +from marker.schema.page import Page +from marker.settings import settings + + +def should_ocr_page(page: Page, no_text: bool): + detected_lines_found = detected_line_coverage(page) + + # OCR page if we got minimal text, or if we got too many spaces + conditions = [ + no_text , # Full doc has no text, and needs full OCR + (len(page.prelim_text) > 0 and detect_bad_ocr(page.prelim_text)), # Bad OCR + detected_lines_found is False, # didn't extract text for all detected lines + ] + + return any(conditions) or settings.OCR_ALL_PAGES + + +def detect_bad_ocr(text, space_threshold=.6, newline_threshold=.5, alphanum_threshold=.4): + if len(text) == 0: + # Assume OCR failed if we have no text + return True + + words = wordpunct_tokenize(text) + words = [w for w in words if w.strip()] + alpha_words = [word for word in words if word.isalnum()] + + spaces = len(re.findall(r'\s+', text)) + alpha_chars = len(re.sub(r'\s+', '', text)) + if spaces / (alpha_chars + spaces) > space_threshold: + return True + + newlines = len(re.findall(r'\n+', text)) + non_newlines = len(re.sub(r'\n+', '', text)) + if newlines / (newlines + non_newlines) > newline_threshold: + return True + + if alphanum_ratio(text) < alphanum_threshold: # Garbled text + return True + + invalid_chars = len([c for c in text if c in settings.INVALID_CHARS]) + if invalid_chars > max(3.0, len(text) * .02): + return True + + return False + + +def no_text_found(pages: List[Page]): + full_text = "" + for page in pages: + full_text += page.text + return len(full_text.strip()) < 10 + + +def detected_line_coverage(page: Page, intersect_thresh=.6, detection_thresh=.5): + found_lines = 0 + total_lines = 0 + for detected_line 
in page.text_lines.bboxes: + detected_bbox = detected_line.bbox + for block in page.blocks: + for line in block.lines: + intersection_pct = line.intersection_pct(detected_bbox) + if intersection_pct > intersect_thresh: + found_lines += 1 + total_lines += 1 + return total_lines == 0 or found_lines / total_lines > detection_thresh diff --git a/marker/ocr/lang.py b/marker/ocr/lang.py new file mode 100644 index 00000000..e4ed7acb --- /dev/null +++ b/marker/ocr/lang.py @@ -0,0 +1,14 @@ +from surya.languages import CODE_TO_LANGUAGE, LANGUAGE_TO_CODE + + +def replace_langs_with_codes(langs): + for i, lang in enumerate(langs): + if lang in LANGUAGE_TO_CODE: + langs[i] = LANGUAGE_TO_CODE[lang] + return langs + + +def validate_langs(langs): + for lang in langs: + if lang not in CODE_TO_LANGUAGE: + raise ValueError(f"Invalid language code {lang}") \ No newline at end of file diff --git a/marker/ocr/page.py b/marker/ocr/page.py deleted file mode 100644 index cdb52f6f..00000000 --- a/marker/ocr/page.py +++ /dev/null @@ -1,75 +0,0 @@ -import io -from typing import List, Optional - -import ocrmypdf - -from marker.ocr.utils import detect_bad_ocr -from marker.schema import Block -from marker.settings import settings - -ocrmypdf.configure_logging(verbosity=ocrmypdf.Verbosity.quiet) - - -def ocr_entire_page(page, lang: str) -> List[Block]: - if settings.OCR_ENGINE == "tesseract": - return ocr_entire_page_tess(page, lang) - elif settings.OCR_ENGINE == "ocrmypdf": - return ocr_entire_page_ocrmp(page, lang) - else: - raise ValueError(f"Unknown OCR engine {settings.OCR_ENGINE}") - - -def ocr_entire_page_tess(page, lang: str) -> List[Block]: - try: - full_tp = page.get_textpage_ocr(flags=settings.TEXT_FLAGS, dpi=settings.OCR_DPI, full=True, language=lang) - blocks = page.get_text("dict", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp)["blocks"] - full_text = page.get_text("text", sort=True, flags=settings.TEXT_FLAGS, textpage=full_tp) - - if len(full_text) == 0: - return [] - - # Check if OCR worked. 
If it didn't, return empty list - # OCR can fail if there is a scanned blank page with some faint text impressions, for example - if detect_bad_ocr(full_text): - return [] - except RuntimeError: - return [] - return blocks - - -def ocr_entire_page_ocrmp(page, lang: str) -> List[Block]: - # Use ocrmypdf to get OCR text for the whole page - src = page.parent # the page's document - blank_doc = pymupdf.open() # make temporary 1-pager - blank_doc.insert_pdf(src, from_page=page.number, to_page=page.number, annots=False, links=False) - pdfbytes = blank_doc.tobytes() - inbytes = io.BytesIO(pdfbytes) # transform to BytesIO object - outbytes = io.BytesIO() # let ocrmypdf store its result pdf here - ocrmypdf.ocr( - inbytes, - outbytes, - language=lang, - output_type="pdf", - redo_ocr=None if settings.OCR_ALL_PAGES else True, - force_ocr=True if settings.OCR_ALL_PAGES else None, - progress_bar=False, - optimize=False, - fast_web_view=1e6, - skip_big=15, # skip images larger than 15 megapixels - tesseract_timeout=settings.TESSERACT_TIMEOUT, - tesseract_non_ocr_timeout=settings.TESSERACT_TIMEOUT, - ) - ocr_pdf = pymupdf.open("pdf", outbytes.getvalue()) # read output as fitz PDF - blocks = ocr_pdf[0].get_text("dict", sort=True, flags=settings.TEXT_FLAGS)["blocks"] - full_text = ocr_pdf[0].get_text("text", sort=True, flags=settings.TEXT_FLAGS) - - # Make sure the original pdf/epub/mobi bbox and the ocr pdf bbox are the same - assert page.bound() == ocr_pdf[0].bound() - - if len(full_text) == 0: - return [] - - if detect_bad_ocr(full_text): - return [] - - return blocks diff --git a/marker/ocr/recognition.py b/marker/ocr/recognition.py new file mode 100644 index 00000000..a439af45 --- /dev/null +++ b/marker/ocr/recognition.py @@ -0,0 +1,136 @@ +from itertools import repeat +from typing import List, Optional, Dict + +import ocrmypdf +import pypdfium2 as pdfium +import io +from concurrent.futures import ThreadPoolExecutor + +from surya.ocr import run_recognition + +from 
marker.ocr.heuristics import should_ocr_page, no_text_found, detect_bad_ocr +from marker.pdf.images import render_image +from marker.schema.page import Page +from marker.schema.schema import Block, Line, Span +from marker.settings import settings +from marker.pdf.extract_text import get_text_blocks + + +def run_ocr(doc, pages: List[Page], langs: List[str], rec_model, parallel_factor) -> (List[Page], Dict): + ocr_pages = 0 + ocr_success = 0 + ocr_failed = 0 + no_text = no_text_found(pages) + ocr_idxs = [] + for pnum, page in enumerate(pages): + ocr_needed = should_ocr_page(page, no_text) + if ocr_needed: + ocr_idxs.append(pnum) + ocr_pages += 1 + + ocr_method = settings.OCR_ENGINE_INTERNAL + if ocr_method == "surya": + new_pages = surya_recognition(doc, ocr_idxs, langs, rec_model, pages) + else: + new_pages = tesseract_recognition(doc, ocr_idxs, langs) + + for orig_idx, page in zip(ocr_idxs, new_pages): + if detect_bad_ocr(page.prelim_text) or len(page.prelim_text) == 0: + ocr_failed += 1 + else: + ocr_success += 1 + pages[orig_idx] = page + + return pages, {"ocr_pages": ocr_pages, "ocr_failed": ocr_failed, "ocr_success": ocr_success} + + +def surya_recognition(doc, page_idxs, langs: List[str], rec_model, pages: List[Page]) -> List[Optional[Page]]: + images = [render_image(doc[pnum], dpi=settings.SURYA_OCR_DPI) for pnum in page_idxs] + processor = rec_model.processor + selected_pages = [p for i, p in enumerate(pages) if i in page_idxs] + + surya_langs = [langs] * len(page_idxs) + detection_results = [p.text_lines.bboxes for p in selected_pages] + polygons = [[b.polygon for b in bboxes] for bboxes in detection_results] + + results = run_recognition(images, surya_langs, rec_model, processor, polygons=polygons) + + new_pages = [] + for (page_idx, result, old_page) in zip(page_idxs, results, selected_pages): + text_lines = old_page.text_lines + ocr_results = result.text_lines + blocks = [] + for i, line in enumerate(ocr_results): + block = Block( + bbox=line.bbox, + pnum=page_idx, 
+ lines=[Line( + bbox=line.bbox, + spans=[Span( + text=line.text, + bbox=line.bbox, + span_id=f"{page_idx}_{i}", + font="", + font_weight=0, + font_size=0, + ) + ] + )] + ) + blocks.append(block) + page = Page( + blocks=blocks, + pnum=page_idx, + bbox=old_page.bbox, + rotation=old_page.rotation, + text_lines=text_lines + ) + new_pages.append(page) + return new_pages + + +def tesseract_recognition(doc, page_idxs, langs: List[str]) -> List[Optional[Page]]: + pdf_pages = generate_single_page_pdfs(doc, page_idxs) + with ThreadPoolExecutor(max_workers=settings.OCR_THREADS) as executor: + pages = list(executor.map(_tesseract_recognition, pdf_pages, repeat(langs, len(pdf_pages)))) + + return pages + + +def generate_single_page_pdfs(doc, page_idxs) -> List[io.BytesIO]: + pdf_pages = [] + for page_idx in page_idxs: + blank_doc = pdfium.PdfDocument.new() + blank_doc.import_pages(doc, pages=[page_idx]) + assert len(blank_doc) == 1, "Failed to import page" + + in_pdf = io.BytesIO() + blank_doc.save(in_pdf) + in_pdf.seek(0) + pdf_pages.append(in_pdf) + return pdf_pages + + +def _tesseract_recognition(in_pdf, langs: List[str]) -> Optional[Page]: + out_pdf = io.BytesIO() + + ocrmypdf.ocr( + in_pdf, + out_pdf, + language=langs[0], + output_type="pdf", + redo_ocr=None if settings.OCR_ALL_PAGES else True, + force_ocr=True if settings.OCR_ALL_PAGES else None, + progress_bar=False, + optimize=False, + fast_web_view=1e6, + skip_big=15, # skip images larger than 15 megapixels + tesseract_timeout=settings.TESSERACT_TIMEOUT, + tesseract_non_ocr_timeout=settings.TESSERACT_TIMEOUT, + ) + + new_doc = pdfium.PdfDocument(out_pdf.getvalue()) + + blocks, _ = get_text_blocks(new_doc, max_pages=1) + page = blocks[0] + return page diff --git a/marker/ocr/utils.py b/marker/ocr/utils.py index 33d1207d..c94682cd 100644 --- a/marker/ocr/utils.py +++ b/marker/ocr/utils.py @@ -1,39 +1,3 @@ -from typing import Optional - -from nltk import wordpunct_tokenize -from marker.settings import settings -import re 
- - -def detect_bad_ocr(text, space_threshold=.6, newline_threshold=.5, alphanum_threshold=.4): - if len(text) == 0: - # Assume OCR failed if we have no text - return True - - words = wordpunct_tokenize(text) - words = [w for w in words if w.strip()] - alpha_words = [word for word in words if word.isalnum()] - - spaces = len(re.findall(r'\s+', text)) - alpha_chars = len(re.sub(r'\s+', '', text)) - if spaces / (alpha_chars + spaces) > space_threshold: - return True - - newlines = len(re.findall(r'\n+', text)) - non_newlines = len(re.sub(r'\n+', '', text)) - if newlines / (newlines + non_newlines) > newline_threshold: - return True - - if alphanum_ratio(text) < alphanum_threshold: # Garbled text - return True - - invalid_chars = len([c for c in text if c in settings.INVALID_CHARS]) - if invalid_chars > max(3.0, len(text) * .02): - return True - - return False - - def font_flags_decomposer(flags): flags = int(flags) diff --git a/marker/ordering.py b/marker/ordering.py index ef172663..fc282a33 100644 --- a/marker/ordering.py +++ b/marker/ordering.py @@ -2,14 +2,11 @@ from typing import List import torch -import sys, os from transformers import LayoutLMv3ForSequenceClassification, LayoutLMv3Processor -from PIL import Image -import io from marker.pdf.images import render_image -from marker.schema import Page +from marker.schema.page import Page from marker.settings import settings processor = LayoutLMv3Processor.from_pretrained(settings.ORDERER_MODEL_NAME) diff --git a/marker/extract_text.py b/marker/pdf/extract_text.py similarity index 62% rename from marker/extract_text.py rename to marker/pdf/extract_text.py index 9dff3b7a..c795b284 100644 --- a/marker/extract_text.py +++ b/marker/pdf/extract_text.py @@ -1,17 +1,21 @@ import os -from typing import List, Optional +from typing import List, Optional, Dict +import pypdfium2 as pdfium import pypdfium2.internal as pdfium_i -from marker.ocr.utils import detect_bad_ocr, font_flags_decomposer +from marker.pdf.filetype import 
find_filetype +from marker.ocr.utils import font_flags_decomposer +from marker.ocr.heuristics import detect_bad_ocr from marker.settings import settings -from marker.schema import Span, Line, Block, Page +from marker.schema.schema import Span, Line, Block +from marker.schema.page import Page from pdftext.extraction import dictionary_output os.environ["TESSDATA_PREFIX"] = settings.TESSDATA_PREFIX -def pdftext_format_to_blocks(page, pnum: int) -> List[Block]: +def pdftext_format_to_blocks(page, pnum: int) -> Page: page_blocks = [] span_id = 0 for block_idx, block in enumerate(page["blocks"]): @@ -54,42 +58,8 @@ def pdftext_format_to_blocks(page, pnum: int) -> List[Block]: return out_page -def ocr_page(doc, pnum, page: Page, tess_lang: str): - ocr_pages = 0 - ocr_success = 0 - ocr_failed = 0 - page_bbox = doc[pnum].bound() - - blocks = get_single_page_blocks(doc, pnum, tess_lang) - page_obj = Page(blocks=blocks, pnum=pnum, bbox=page_bbox) - - # OCR page if we got minimal text, or if we got too many spaces - conditions = [ - ( - no_text # Full doc has no text, and needs full OCR - or - (len(page_obj.prelim_text) > 0 and detect_bad_ocr(page_obj.prelim_text)) # Bad OCR - ), - min_ocr_page < pnum < len(doc) - 1, - not disable_ocr - ] - if all(conditions) or settings.OCR_ALL_PAGES: - page = doc[pnum] - blocks = get_single_page_blocks(doc, pnum, tess_lang, ocr=True) - page_obj = Page(blocks=blocks, pnum=pnum, bbox=page_bbox, rotation=page.rotation) - ocr_pages = 1 - if len(blocks) == 0: - ocr_failed = 1 - else: - ocr_success = 1 - return page_obj, {"ocr_pages": ocr_pages, "ocr_failed": ocr_failed, "ocr_success": ocr_success} - - -def get_text_blocks(doc, tess_lang: str, spell_lang: Optional[str], max_pages: Optional[int] = None, parallel: int = settings.OCR_PARALLEL_WORKERS): +def get_text_blocks(doc, max_pages: Optional[int] = None) -> (List[Page], Dict): toc = get_toc(doc) - ocr_pages = 0 - ocr_failed = 0 - ocr_success = 0 page_range = range(len(doc)) if max_pages: @@ 
-99,7 +69,7 @@ def get_text_blocks(doc, tess_lang: str, spell_lang: Optional[str], max_pages: O all_blocks = dictionary_output(doc, page_range=page_range) all_blocks = [pdftext_format_to_blocks(page, pnum) for pnum, page in enumerate(all_blocks)] - return all_blocks, toc, {"ocr_pages": ocr_pages, "ocr_failed": ocr_failed, "ocr_success": ocr_success} + return all_blocks, toc def naive_get_text(doc): @@ -126,3 +96,10 @@ def get_toc(doc, max_depth=15): } toc_list.append(list_item) return toc_list + + +def get_length_of_text(fname: str) -> int: + doc = pdfium.PdfDocument(fname) + text = naive_get_text(doc).strip() + + return len(text) diff --git a/marker/pdf/filetype.py b/marker/pdf/filetype.py new file mode 100644 index 00000000..7311e061 --- /dev/null +++ b/marker/pdf/filetype.py @@ -0,0 +1,21 @@ +import magic + +from marker.settings import settings + + +def find_filetype(fpath): + mimetype = magic.from_file(fpath).lower() + + # Get extensions from mimetype + # The mimetype is not always consistent, so use in to check the most common formats + if "pdf" in mimetype: + return "pdf" + #elif "epub" in mimetype: + # return "epub" + #elif "mobi" in mimetype: + # return "mobi" + elif mimetype in settings.SUPPORTED_FILETYPES: + return settings.SUPPORTED_FILETYPES[mimetype] + else: + print(f"Found nonstandard filetype {mimetype}") + return "other" diff --git a/marker/schema.py b/marker/schema.py deleted file mode 100644 index 224ac69e..00000000 --- a/marker/schema.py +++ /dev/null @@ -1,212 +0,0 @@ -from collections import Counter -from typing import List, Optional, Tuple - -from pydantic import BaseModel, field_validator -import ftfy - -from marker.bbox import boxes_intersect_pct, multiple_boxes_intersect -from marker.settings import settings - - -def find_span_type(span, page_blocks): - block_type = "Text" - for block in page_blocks: - if boxes_intersect_pct(span.bbox, block.bbox): - block_type = block.block_type - break - return block_type - - -class 
BboxElement(BaseModel): - bbox: List[float] - - @field_validator('bbox') - @classmethod - def check_4_elements(cls, v: List[float]) -> List[float]: - if len(v) != 4: - raise ValueError('bbox must have 4 elements') - return v - - @property - def height(self): - return self.bbox[3] - self.bbox[1] - - @property - def width(self): - return self.bbox[2] - self.bbox[0] - - @property - def x_start(self): - return self.bbox[0] - - @property - def y_start(self): - return self.bbox[1] - - @property - def area(self): - return self.width * self.height - - -class BlockType(BboxElement): - block_type: str - - -class Span(BboxElement): - text: str - span_id: str - font: str - font_weight: float - font_size: float - block_type: Optional[str] = None - selected: bool = True - - - @field_validator('text') - @classmethod - def fix_unicode(cls, text: str) -> str: - return ftfy.fix_text(text) - - -class Line(BboxElement): - spans: List[Span] - - @property - def prelim_text(self): - return "".join([s.text for s in self.spans]) - - @property - def start(self): - return self.spans[0].bbox[0] - - -class Block(BboxElement): - lines: List[Line] - pnum: int - - @property - def prelim_text(self): - return "\n".join([l.prelim_text for l in self.lines]) - - def contains_equation(self, equation_boxes=None): - conditions = [s.block_type == "Formula" for l in self.lines for s in l.spans] - if equation_boxes: - conditions += [multiple_boxes_intersect(self.bbox, equation_boxes)] - return any(conditions) - - def filter_spans(self, bad_span_ids): - new_lines = [] - for line in self.lines: - new_spans = [] - for span in line.spans: - if not span.span_id in bad_span_ids: - new_spans.append(span) - line.spans = new_spans - if len(new_spans) > 0: - new_lines.append(line) - self.lines = new_lines - - def filter_bad_span_types(self): - new_lines = [] - for line in self.lines: - new_spans = [] - for span in line.spans: - if span.block_type not in settings.BAD_SPAN_TYPES: - new_spans.append(span) - line.spans = 
new_spans - if len(new_spans) > 0: - new_lines.append(line) - self.lines = new_lines - - def most_common_block_type(self): - counter = Counter([s.block_type for l in self.lines for s in l.spans]) - return counter.most_common(1)[0][0] - - def set_block_type(self, block_type): - for line in self.lines: - for span in line.spans: - span.block_type = block_type - - -class Page(BboxElement): - blocks: List[Block] - pnum: int - column_count: Optional[int] = None - rotation: Optional[int] = None # Rotation degrees of the page - - def get_nonblank_lines(self): - lines = self.get_all_lines() - nonblank_lines = [l for l in lines if l.prelim_text.strip()] - return nonblank_lines - - def get_all_lines(self): - lines = [l for b in self.blocks for l in b.lines] - return lines - - def get_nonblank_spans(self) -> List[Span]: - lines = [l for b in self.blocks for l in b.lines] - spans = [s for l in lines for s in l.spans if s.text.strip()] - return spans - - def add_block_types(self, page_block_types): - if len(page_block_types) != len(self.get_all_lines()): - print(f"Warning: Number of detected lines {len(page_block_types)} does not match number of lines {len(self.get_all_lines())}") - - i = 0 - for block in self.blocks: - for line in block.lines: - if i < len(page_block_types): - line_block_type = page_block_types[i].block_type - else: - line_block_type = "Text" - i += 1 - for span in line.spans: - span.block_type = line_block_type - - def get_font_stats(self): - fonts = [s.font for s in self.get_nonblank_spans()] - font_counts = Counter(fonts) - return font_counts - - def get_line_height_stats(self): - heights = [l.bbox[3] - l.bbox[1] for l in self.get_nonblank_lines()] - height_counts = Counter(heights) - return height_counts - - def get_line_start_stats(self): - starts = [l.bbox[0] for l in self.get_nonblank_lines()] - start_counts = Counter(starts) - return start_counts - - def get_min_line_start(self): - starts = [l.bbox[0] for l in self.get_nonblank_lines() if 
l.spans[0].block_type == "Text"] - if len(starts) == 0: - raise IndexError("No lines found") - return min(starts) - - @property - def prelim_text(self): - return "\n".join([b.prelim_text for b in self.blocks]) - -class MergedLine(BboxElement): - text: str - fonts: List[str] - - def most_common_font(self): - counter = Counter(self.fonts) - return counter.most_common(1)[0][0] - - -class MergedBlock(BboxElement): - lines: List[MergedLine] - pnum: int - block_types: List[str] - - def most_common_block_type(self): - counter = Counter(self.block_types) - return counter.most_common(1)[0][0] - - -class FullyMergedBlock(BaseModel): - text: str - block_type: str diff --git a/marker/bbox.py b/marker/schema/bbox.py similarity index 60% rename from marker/bbox.py rename to marker/schema/bbox.py index a8437b3f..aacded83 100644 --- a/marker/bbox.py +++ b/marker/schema/bbox.py @@ -1,3 +1,8 @@ +from typing import List + +from pydantic import BaseModel, field_validator + + def should_merge_blocks(box1, box2, tol=5): # Within tol y px, and to the right within tol px merge = [ @@ -18,7 +23,7 @@ def boxes_intersect(box1, box2): return box1[0] < box2[2] and box1[2] > box2[0] and box1[1] < box2[3] and box1[3] > box2[1] -def boxes_intersect_pct(box1, box2, pct=.9): +def box_intersection_pct(box1, box2): # determine the coordinates of the intersection rectangle x_left = max(box1[0], box2[0]) y_top = max(box1[1], box2[1]) @@ -28,16 +33,11 @@ def boxes_intersect_pct(box1, box2, pct=.9): if x_right < x_left or y_bottom < y_top: return 0.0 - # The intersection of two axis-aligned bounding boxes is always an - # axis-aligned bounding box intersection_area = (x_right - x_left) * (y_bottom - y_top) - - # compute the area of both AABBs bb1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) - bb2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) - iou = intersection_area / float(bb1_area + bb2_area - intersection_area) - return iou > pct + iou = intersection_area / bb1_area + return iou def 
multiple_boxes_intersect(box1, boxes): @@ -47,15 +47,44 @@ def multiple_boxes_intersect(box1, boxes): return False -def box_contained(box1, box2): - # Box1 inside box2 - return box1[0] > box2[0] and box1[1] > box2[1] and box1[2] < box2[2] and box1[3] < box2[3] - - def unnormalize_box(bbox, width, height): return [ width * (bbox[0] / 1000), height * (bbox[1] / 1000), width * (bbox[2] / 1000), height * (bbox[3] / 1000), - ] \ No newline at end of file + ] + + +class BboxElement(BaseModel): + bbox: List[float] + + @field_validator('bbox') + @classmethod + def check_4_elements(cls, v: List[float]) -> List[float]: + if len(v) != 4: + raise ValueError('bbox must have 4 elements') + return v + + @property + def height(self): + return self.bbox[3] - self.bbox[1] + + @property + def width(self): + return self.bbox[2] - self.bbox[0] + + @property + def x_start(self): + return self.bbox[0] + + @property + def y_start(self): + return self.bbox[1] + + @property + def area(self): + return self.width * self.height + + def intersection_pct(self, other_bbox: List[float]): + return box_intersection_pct(self.bbox, other_bbox) diff --git a/marker/schema/page.py b/marker/schema/page.py new file mode 100644 index 00000000..835633f7 --- /dev/null +++ b/marker/schema/page.py @@ -0,0 +1,68 @@ +from collections import Counter +from typing import List, Optional + +from marker.schema.bbox import BboxElement +from marker.schema.schema import Block, Span +from surya.schema import TextDetectionResult + + +class Page(BboxElement): + blocks: List[Block] + pnum: int + column_count: Optional[int] = None + rotation: Optional[int] = None # Rotation degrees of the page + text_lines: Optional[TextDetectionResult] = None + + def get_nonblank_lines(self): + lines = self.get_all_lines() + nonblank_lines = [l for l in lines if l.prelim_text.strip()] + return nonblank_lines + + def get_all_lines(self): + lines = [l for b in self.blocks for l in b.lines] + return lines + + def get_nonblank_spans(self) -> 
List[Span]: + lines = [l for b in self.blocks for l in b.lines] + spans = [s for l in lines for s in l.spans if s.text.strip()] + return spans + + def add_block_types(self, page_block_types): + if len(page_block_types) != len(self.get_all_lines()): + print(f"Warning: Number of detected lines {len(page_block_types)} does not match number of lines {len(self.get_all_lines())}") + + i = 0 + for block in self.blocks: + for line in block.lines: + if i < len(page_block_types): + line_block_type = page_block_types[i].block_type + else: + line_block_type = "Text" + i += 1 + for span in line.spans: + span.block_type = line_block_type + + def get_font_stats(self): + fonts = [s.font for s in self.get_nonblank_spans()] + font_counts = Counter(fonts) + return font_counts + + def get_line_height_stats(self): + heights = [l.bbox[3] - l.bbox[1] for l in self.get_nonblank_lines()] + height_counts = Counter(heights) + return height_counts + + def get_line_start_stats(self): + starts = [l.bbox[0] for l in self.get_nonblank_lines()] + start_counts = Counter(starts) + return start_counts + + def get_min_line_start(self): + starts = [l.bbox[0] for l in self.get_nonblank_lines() if l.spans[0].block_type == "Text"] + if len(starts) == 0: + raise IndexError("No lines found") + return min(starts) + + @property + def prelim_text(self): + return "\n".join([b.prelim_text for b in self.blocks]) diff --git a/marker/schema/schema.py b/marker/schema/schema.py new file mode 100644 index 00000000..bca1df51 --- /dev/null +++ b/marker/schema/schema.py @@ -0,0 +1,111 @@ +from collections import Counter +from typing import List, Optional + +from pydantic import BaseModel, field_validator +import ftfy + +from marker.schema.bbox import multiple_boxes_intersect, BboxElement +from marker.settings import settings + + +class BlockType(BboxElement): + block_type: str + + +class Span(BboxElement): + text: str + span_id: str + font: str + font_weight: float + font_size: float + block_type: Optional[str] = None + 
+ + @field_validator('text') + @classmethod + def fix_unicode(cls, text: str) -> str: + return ftfy.fix_text(text) + + +class Line(BboxElement): + spans: List[Span] + + @property + def prelim_text(self): + return "".join([s.text for s in self.spans]) + + @property + def start(self): + return self.spans[0].bbox[0] + + +class Block(BboxElement): + lines: List[Line] + pnum: int + + @property + def prelim_text(self): + return "\n".join([l.prelim_text for l in self.lines]) + + def contains_equation(self, equation_boxes=None): + conditions = [s.block_type == "Formula" for l in self.lines for s in l.spans] + if equation_boxes: + conditions += [multiple_boxes_intersect(self.bbox, equation_boxes)] + return any(conditions) + + def filter_spans(self, bad_span_ids): + new_lines = [] + for line in self.lines: + new_spans = [] + for span in line.spans: + if not span.span_id in bad_span_ids: + new_spans.append(span) + line.spans = new_spans + if len(new_spans) > 0: + new_lines.append(line) + self.lines = new_lines + + def filter_bad_span_types(self): + new_lines = [] + for line in self.lines: + new_spans = [] + for span in line.spans: + if span.block_type not in settings.BAD_SPAN_TYPES: + new_spans.append(span) + line.spans = new_spans + if len(new_spans) > 0: + new_lines.append(line) + self.lines = new_lines + + def most_common_block_type(self): + counter = Counter([s.block_type for l in self.lines for s in l.spans]) + return counter.most_common(1)[0][0] + + def set_block_type(self, block_type): + for line in self.lines: + for span in line.spans: + span.block_type = block_type + + +class MergedLine(BboxElement): + text: str + fonts: List[str] + + def most_common_font(self): + counter = Counter(self.fonts) + return counter.most_common(1)[0][0] + + +class MergedBlock(BboxElement): + lines: List[MergedLine] + pnum: int + block_types: List[str] + + def most_common_block_type(self): + counter = Counter(self.block_types) + return counter.most_common(1)[0][0] + + +class 
FullyMergedBlock(BaseModel): + text: str + block_type: str diff --git a/marker/segmentation.py b/marker/segmentation.py index d3a40de3..a7587a51 100644 --- a/marker/segmentation.py +++ b/marker/segmentation.py @@ -1,18 +1,17 @@ -from concurrent.futures import ThreadPoolExecutor from typing import List from transformers import LayoutLMv3ForTokenClassification -from marker.bbox import unnormalize_box +from marker.schema.bbox import unnormalize_box from transformers.models.layoutlmv3.image_processing_layoutlmv3 import normalize_box -import io from PIL import Image from transformers import LayoutLMv3Processor import numpy as np from marker.pdf.images import render_image from marker.settings import settings -from marker.schema import Page, BlockType +from marker.schema.schema import BlockType +from marker.schema.page import Page import torch from math import isclose diff --git a/marker/settings.py b/marker/settings.py index fe42004c..d4a6c048 100644 --- a/marker/settings.py +++ b/marker/settings.py @@ -37,38 +37,34 @@ def TORCH_DEVICE_MODEL(self) -> str: #"application/x-fictionbook+xml": "fb2" } + # Text line Detection + DETECTOR_BATCH_SIZE: Optional[int] = None + SURYA_DETECTOR_DPI: int = 96 + DETECTOR_POSTPROCESSING_CPU_WORKERS: int = 4 + # OCR INVALID_CHARS: List[str] = [chr(0xfffd), "�"] - OCR_DPI: int = 400 - TESSDATA_PREFIX: str = "" - TESSERACT_LANGUAGES: Dict = { - "English": "eng", - "Spanish": "spa", - "Portuguese": "por", - "French": "fra", - "German": "deu", - "Russian": "rus", - "Chinese": "chi_sim", - "Japanese": "jpn", - "Korean": "kor", - "Hindi": "hin", - } - TESSERACT_TIMEOUT: int = 20 # When to give up on OCR - SPELLCHECK_LANGUAGES: Dict = { - "English": "en", - "Spanish": "es", - "Portuguese": "pt", - "French": "fr", - "German": "de", - "Russian": "ru", - "Chinese": None, - "Japanese": None, - "Korean": None, - "Hindi": None, - } + OCR_ENGINE: Optional[str] = None # Which OCR engine to use, either "surya" or "ocrmypdf". 
Defaults to "ocrmypdf" on CPU, "surya" on GPU. OCR_ALL_PAGES: bool = False # Run OCR on every page even if text can be extracted + + ## Surya + SURYA_OCR_DPI: int = 96 + RECOGNITION_BATCH_SIZE: Optional[int] = None # Batch size for surya OCR + + ## Tesseract OCR_PARALLEL_WORKERS: int = 2 # How many CPU workers to use for OCR - OCR_ENGINE: str = "ocrmypdf" # Which OCR engine to use, either "tesseract" or "ocrmypdf". Ocrmypdf is higher quality, but slower. + TESSERACT_TIMEOUT: int = 20 # When to give up on OCR + + @computed_field + def OCR_ENGINE_INTERNAL(self) -> str: + if self.OCR_ENGINE is not None: + return self.OCR_ENGINE + + # Does not work with mps + if torch.cuda.is_available(): + return "surya" + + return "ocrmypdf" # Texify model TEXIFY_MODEL_MAX: int = 384 # Max inference length for texify @@ -82,7 +78,7 @@ def TORCH_DEVICE_MODEL(self) -> str: LAYOUT_MODEL_MAX: int = 512 LAYOUT_CHUNK_OVERLAP: int = 64 LAYOUT_DPI: int = 96 - LAYOUT_MODEL_NAME: str = "vikp/layout_segmenter" + LAYOUT_MODEL_CHECKPOINT: str = "vikp/layout_segmenter" LAYOUT_BATCH_SIZE: int = 8 # Max 512 tokens means high batch size # Ordering model