diff --git a/marker/v2/builders/document.py b/marker/v2/builders/document.py index a89b013a..af5d9f44 100644 --- a/marker/v2/builders/document.py +++ b/marker/v2/builders/document.py @@ -3,8 +3,10 @@ from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.providers.pdf import PdfProvider +from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup +from marker.v2.schema.registry import get_block_class class DocumentBuilder(BaseBuilder): @@ -15,13 +17,14 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui return document def build_document(self, provider: PdfProvider): + PageGroupClass: PageGroup = get_block_class(BlockTypes.Page) initial_pages = [ - PageGroup( + PageGroupClass( page_id=i, lowres_image=provider.get_image(i, settings.IMAGE_DPI), highres_image=provider.get_image(i, settings.HIGHRES_IMAGE_DPI), polygon=provider.get_page_bbox(i) ) for i in provider.page_range ] - - return Document(filepath=provider.filepath, pages=initial_pages) + DocumentClass: Document = get_block_class(BlockTypes.Document) + return DocumentClass(filepath=provider.filepath, pages=initial_pages) diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py index 4683e890..0b4df842 100644 --- a/marker/v2/builders/layout.py +++ b/marker/v2/builders/layout.py @@ -7,10 +7,10 @@ from marker.v2.builders import BaseBuilder from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider from marker.v2.schema import BlockTypes -from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY from marker.v2.schema.document import Document from marker.v2.schema.groups.page import PageGroup from marker.v2.schema.polygon import PolygonBox +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line @@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size provider_page_size = page.polygon.size for bbox in sorted(layout_result.bboxes, key=lambda x: x.position): - block_cls = LAYOUT_BLOCK_REGISTRY[BlockTypes[bbox.label]] + block_cls = get_block_class(BlockTypes[bbox.label]) layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon)) layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size) page.add_structure(layout_block) diff --git a/marker/v2/builders/ocr.py b/marker/v2/builders/ocr.py index bcf757a4..9f40da9a 100644 --- a/marker/v2/builders/ocr.py +++ b/marker/v2/builders/ocr.py @@ -6,9 +6,9 @@ from marker.v2.builders import BaseBuilder from marker.v2.providers.pdf import PdfProvider from marker.v2.schema import BlockTypes -from marker.v2.schema.blocks import Block from marker.v2.schema.document import Document from marker.v2.schema.polygon import PolygonBox +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span @@ -63,6 +63,9 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag page_lines = {} page_spans = {} + SpanClass: Span = get_block_class(BlockTypes.Span) + LineClass: Line = get_block_class(BlockTypes.Line) + for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results): page_spans.setdefault(page_id, {}) page_lines.setdefault(page_id, []) @@ -74,13 +77,13 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox) polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size) - page_lines[page_id].append(Line( + page_lines[page_id].append(LineClass( polygon=polygon, page_id=page_id, )) line_spans.setdefault(ocr_line_idx, []) - line_spans[ocr_line_idx].append(Span( + line_spans[ocr_line_idx].append(SpanClass( text=ocr_line.text, formats=['plain'], page_id=page_id, diff --git a/marker/v2/builders/structure.py b/marker/v2/builders/structure.py index 90f950ce..b316ddb4 100644 --- a/marker/v2/builders/structure.py +++ b/marker/v2/builders/structure.py @@ -1,12 +1,9 @@ -from typing import Optional - -from pydantic import BaseModel - from marker.v2.builders import BaseBuilder from marker.v2.schema import BlockTypes from marker.v2.schema.document import Document -from marker.v2.schema.groups import GROUP_BLOCK_REGISTRY, ListGroup +from marker.v2.schema.groups import ListGroup from marker.v2.schema.groups.page import PageGroup +from marker.v2.schema.registry import get_block_class class StructureBuilder(BaseBuilder): @@ -53,7 +50,7 @@ def group_caption_blocks(self, page: PageGroup): if len(block_structure) > 1: # Create a merged block - new_block_cls = GROUP_BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]] + new_block_cls = get_block_class(BlockTypes[block.block_type.name + "Group"]) new_polygon = block.polygon.merge(selected_polygons) group_block = page.add_block(new_block_cls, new_polygon) group_block.structure = block_structure diff --git a/marker/v2/converters/pdf.py b/marker/v2/converters/pdf.py index 47135819..55caad9e 100644 --- a/marker/v2/converters/pdf.py +++ b/marker/v2/converters/pdf.py @@ -1,10 +1,11 @@ +from marker.v2.providers.pdf import PdfProvider import os os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning -from marker.v2.processors.sectionheader import SectionHeaderProcessor -from marker.v2.providers.pdf import PdfProvider import tempfile +from collections import defaultdict +from typing import Dict, Type import click import datasets @@ -14,17 +15,26 @@ from marker.v2.builders.ocr import OcrBuilder from marker.v2.builders.structure import StructureBuilder from marker.v2.converters import BaseConverter +from marker.v2.models import setup_detection_model, setup_layout_model, \ + setup_recognition_model, setup_table_rec_model, setup_texify_model from marker.v2.processors.equation import EquationProcessor +from marker.v2.processors.sectionheader import SectionHeaderProcessor from marker.v2.processors.table import TableProcessor -from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \ - setup_detection_model from marker.v2.renderers.markdown import MarkdownRenderer +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import Block +from marker.v2.schema.registry import register_block_class from marker.v2.processors.debug import DebugProcessor class PdfConverter(BaseConverter): + override_map: Dict[BlockTypes, Type[Block]] = defaultdict() + def __init__(self, config=None): super().__init__(config) + + for block_type, override_block_type in self.override_map.items(): + register_block_class(block_type, override_block_type) self.layout_model = setup_layout_model() self.texify_model = setup_texify_model() diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py index d4492ef1..77d49c2f 100644 --- a/marker/v2/providers/pdf.py +++ b/marker/v2/providers/pdf.py @@ -4,11 +4,12 @@ import pypdfium2 as pdfium from pdftext.extraction import dictionary_output from PIL import Image -from pydantic import BaseModel from marker.ocr.heuristics import detect_bad_ocr from marker.v2.providers import BaseProvider from marker.v2.schema.polygon import PolygonBox +from marker.v2.schema import BlockTypes +from marker.v2.schema.registry import get_block_class from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span @@ -23,7 +24,7 @@ class PdfProvider(BaseProvider): flatten_pdf: bool = True force_ocr: bool = False - def __init__(self, filepath: str, config = None): + def __init__(self, filepath: str, config=None): super().__init__(filepath, config) self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath) @@ -105,6 +106,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: workers=self.pdftext_workers, flatten_pdf=self.flatten_pdf ) + SpanClass: Span = get_block_class(BlockTypes.Span) + LineClass: Span = get_block_class(BlockTypes.Line) for page in page_char_blocks: page_id = page["page"] lines: List[Line] = [] @@ -120,7 +123,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: font_weight = span["font"]["weight"] or 0 font_size = span["font"]["size"] or 0 spans.append( - Span( + SpanClass( polygon=PolygonBox.from_bbox(span["bbox"]), text=span["text"], font=font_name, @@ -133,7 +136,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]: text_extraction_method="pdftext" ) ) - lines.append(Line(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id)) + lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id)) line_spans[len(lines) - 1] = spans if self.check_line_spans(line_spans): page_lines[page_id] = lines diff --git a/marker/v2/schema/blocks/__init__.py b/marker/v2/schema/blocks/__init__.py index d10f8759..09db403c 100644 --- a/marker/v2/schema/blocks/__init__.py +++ b/marker/v2/schema/blocks/__init__.py @@ -17,11 +17,3 @@ from marker.v2.schema.blocks.table import Table from marker.v2.schema.blocks.text import Text from marker.v2.schema.blocks.toc import TableOfContents - -LAYOUT_BLOCK_REGISTRY = { - v.model_fields['block_type'].default: v for k, v in locals().items() - if isinstance(v, type) - and issubclass(v, Block) - and v != Block # Exclude the base Block class - and not v.model_fields['block_type'].default.name.endswith("Group") -} diff --git a/marker/v2/schema/groups/__init__.py b/marker/v2/schema/groups/__init__.py index a676ede2..43bf9829 100644 --- a/marker/v2/schema/groups/__init__.py +++ b/marker/v2/schema/groups/__init__.py @@ -4,11 +4,3 @@ from marker.v2.schema.groups.list import ListGroup from marker.v2.schema.groups.picture import PictureGroup from marker.v2.schema.groups.page import PageGroup - -GROUP_BLOCK_REGISTRY = { - v.model_fields['block_type'].default: v for k, v in locals().items() - if isinstance(v, type) - and issubclass(v, Block) - and v != Block # Exclude the base Block class - and (v.model_fields['block_type'].default.name.endswith("Group") or v.model_fields['block_type'].default.name == "Page") -} diff --git a/marker/v2/schema/registry.py b/marker/v2/schema/registry.py new file mode 100644 index 00000000..bf17f99c --- /dev/null +++ b/marker/v2/schema/registry.py @@ -0,0 +1,55 @@ +from typing import Dict, Type +from importlib import import_module + +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \ + Footnote, Form, Handwriting, InlineMath, \ + ListItem, PageFooter, PageHeader, Picture, \ + SectionHeader, Table, TableOfContents, \ + Text +from marker.v2.schema.document import Document +from marker.v2.schema.groups import FigureGroup, ListGroup, PageGroup, \ + PictureGroup, TableGroup +from marker.v2.schema.text import Line, Span + +BLOCK_REGISTRY: Dict[BlockTypes, str] = {} + + +def register_block_class(block_type: BlockTypes, block_cls: Type[Block]): + BLOCK_REGISTRY[block_type] = f"{block_cls.__module__}.{block_cls.__name__}" + + +def get_block_class(block_type: BlockTypes) -> Type[Block]: + class_path = BLOCK_REGISTRY[block_type] + module_name, class_name = class_path.rsplit('.', 1) + module = import_module(module_name) + return getattr(module, class_name) + + +register_block_class(BlockTypes.Line, Line) +register_block_class(BlockTypes.Span, Span) +register_block_class(BlockTypes.FigureGroup, FigureGroup) +register_block_class(BlockTypes.TableGroup, TableGroup) +register_block_class(BlockTypes.ListGroup, ListGroup) +register_block_class(BlockTypes.PictureGroup, PictureGroup) +register_block_class(BlockTypes.Page, PageGroup) +register_block_class(BlockTypes.Caption, Caption) +register_block_class(BlockTypes.Code, Code) +register_block_class(BlockTypes.Figure, Figure) +register_block_class(BlockTypes.Footnote, Footnote) +register_block_class(BlockTypes.Form, Form) +register_block_class(BlockTypes.Equation, Equation) +register_block_class(BlockTypes.Handwriting, Handwriting) +register_block_class(BlockTypes.TextInlineMath, InlineMath) +register_block_class(BlockTypes.ListItem, ListItem) +register_block_class(BlockTypes.PageFooter, PageFooter) +register_block_class(BlockTypes.PageHeader, PageHeader) +register_block_class(BlockTypes.Picture, Picture) +register_block_class(BlockTypes.SectionHeader, SectionHeader) +register_block_class(BlockTypes.Table, Table) +register_block_class(BlockTypes.Text, Text) +register_block_class(BlockTypes.TableOfContents, TableOfContents) +register_block_class(BlockTypes.Document, Document) + +assert len(BLOCK_REGISTRY) == len(BlockTypes) +assert all([get_block_class(k).model_fields['block_type'].default == k for k, _ in BLOCK_REGISTRY.items()]) diff --git a/marker/v2/schema/text/__init__.py b/marker/v2/schema/text/__init__.py index 64015b0d..8901b65f 100644 --- a/marker/v2/schema/text/__init__.py +++ b/marker/v2/schema/text/__init__.py @@ -1,8 +1,2 @@ -from marker.v2.schema import BlockTypes from marker.v2.schema.text.line import Line from marker.v2.schema.text.span import Span - -TEXT_BLOCK_REGISTRY = { - BlockTypes.Line: Line, - BlockTypes.Span: Span, -} diff --git a/poetry.lock b/poetry.lock index e4e8dfcb..c4c900ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1716,6 +1716,21 @@ profiling = ["gprof2dot"] rtd = ["jupyter_sphinx", "mdit-py-plugins", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx_book_theme"] testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] +[[package]] +name = "markdownify" +version = "0.13.1" +description = "Convert HTML to markdown." +optional = false +python-versions = "*" +files = [ + {file = "markdownify-0.13.1-py3-none-any.whl", hash = "sha256:1d181d43d20902bcc69d7be85b5316ed174d0dda72ff56e14ae4c95a4a407d22"}, + {file = "markdownify-0.13.1.tar.gz", hash = "sha256:ab257f9e6bd4075118828a28c9d02f8a4bfeb7421f558834aa79b2dfeb32a098"}, +] + +[package.dependencies] +beautifulsoup4 = ">=4.9,<5" +six = ">=1.15,<2" + [[package]] name = "markupsafe" version = "3.0.2" @@ -5208,4 +5223,4 @@ propcache = ">=0.2.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "87f2bf6e84db6c7100db24777a4104096cb963d9bea7380e520dbc1bd8f2feb7" +content-hash = "4cb4a1d2f40994498d5657bcb5ab37e7ec9c4bb867015d34131c53000eacfbee" diff --git a/pyproject.toml b/pyproject.toml index d2b49103..233e6ecd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ filetype = "^1.2.0" regex = "^2024.4.28" pdftext = "^0.3.18" tabled-pdf = { git = "https://github.com/VikParuchuri/tabled.git", branch = "dev-mose/compilation-updates" } +markdownify = "^0.13.1" [tool.poetry.group.dev.dependencies] jupyter = "^1.0.0" diff --git a/tests/conftest.py b/tests/conftest.py index 178470a6..31335a1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,13 +3,17 @@ import datasets import pytest +from typing import Dict, Type +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import Block from marker.v2.models import setup_layout_model, setup_texify_model, setup_recognition_model, setup_table_rec_model, \ setup_detection_model from marker.v2.builders.document import DocumentBuilder from marker.v2.builders.layout import LayoutBuilder from marker.v2.builders.ocr import OcrBuilder from marker.v2.schema.document import Document +from marker.v2.schema.registry import register_block_class @pytest.fixture(scope="session") @@ -48,13 +52,22 @@ def table_rec_model(): @pytest.fixture(scope="function") -def pdf_provider(request): +def config(request): + config_mark = request.node.get_closest_marker("config") + config = config_mark.args[0] if config_mark else {} + + override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {}) + for block_type, override_block_type in override_map.items(): + register_block_class(block_type, override_block_type) + + return config + + +@pytest.fixture(scope="function") +def pdf_provider(request, config): filename_mark = request.node.get_closest_marker("filename") filename = filename_mark.args[0] if filename_mark else "adversarial.pdf" - config_mark = request.node.get_closest_marker("config") - config = config_mark.args[0] if config_mark else None - dataset = datasets.load_dataset("datalab-to/pdfs", split="train") idx = dataset['filename'].index(filename) @@ -65,10 +78,7 @@ def pdf_provider(request): @pytest.fixture(scope="function") -def pdf_document(request, pdf_provider, layout_model, recognition_model, detection_model) -> Document: - config_mark = request.node.get_closest_marker("config") - config = config_mark.args[0] if config_mark else None - +def pdf_document(request, config, pdf_provider, layout_model, recognition_model, detection_model) -> Document: layout_builder = LayoutBuilder(layout_model, config) ocr_builder = OcrBuilder(detection_model, recognition_model, config) builder = DocumentBuilder(config) diff --git a/tests/test_overriding.py b/tests/test_overriding.py new file mode 100644 index 00000000..01c4e588 --- /dev/null +++ b/tests/test_overriding.py @@ -0,0 +1,49 @@ +import multiprocessing as mp + +import pytest + +from marker.v2.providers.pdf import PdfProvider +from marker.v2.schema import BlockTypes +from marker.v2.schema.blocks import SectionHeader +from marker.v2.schema.document import Document +from marker.v2.schema.registry import register_block_class +from marker.v2.schema.text import Line +from tests.utils import setup_pdf_provider + + +class NewSectionHeader(SectionHeader): + pass + + +class NewLine(Line): + pass + + +@pytest.mark.config({ + "page_range": [0], + "override_map": {BlockTypes.SectionHeader: NewSectionHeader} +}) +def test_overriding(pdf_document: Document): + assert pdf_document.pages[0]\ + .get_block(pdf_document.pages[0].structure[0]).__class__ == NewSectionHeader + + +def get_lines(pdf: str, config=None): + provider: PdfProvider = setup_pdf_provider(pdf, config) + return provider.get_page_lines(0) + + +def test_overriding_mp(): + config = { + "page_range": [0], + "override_map": {BlockTypes.Line: NewLine} + } + + for block_type, block_cls in config["override_map"].items(): + register_block_class(block_type, block_cls) + + pdf_list = ["adversarial.pdf", "adversarial_rot.pdf"] + + with mp.Pool(processes=2) as pool: + results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list]) + assert all([r[0].__class__ == NewLine for r in results]) diff --git a/tests/utils.py b/tests/utils.py index b2e01db5..1f7a2911 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,10 +9,10 @@ from marker.v2.schema.document import Document -def setup_pdf_document( +def setup_pdf_provider( filename='adversarial.pdf', config=None, -) -> Document: +) -> PdfProvider: dataset = datasets.load_dataset("datalab-to/pdfs", split="train") idx = dataset['filename'].index(filename) @@ -20,11 +20,19 @@ def setup_pdf_document( temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() + provider = PdfProvider(temp_pdf.name, config) + return provider + + +def setup_pdf_document( + filename='adversarial.pdf', + config=None, +) -> Document: layout_model = setup_layout_model() recognition_model = setup_recognition_model() detection_model = setup_detection_model() - provider = PdfProvider(temp_pdf.name, config) + provider = setup_pdf_provider(filename, config) layout_builder = LayoutBuilder(layout_model, config) ocr_builder = OcrBuilder(detection_model, recognition_model, config) builder = DocumentBuilder(config)