# Patch: consolidate the per-package block registries into one module.
#
# Adds marker/v2/schema/registry.py, which defines BLOCK_REGISTRY (BlockTypes
# member -> class) and get_block_cls(), and removes LAYOUT_BLOCK_REGISTRY,
# GROUP_BLOCK_REGISTRY and TEXT_BLOCK_REGISTRY from the package __init__ files.
# The builders and the PDF provider now obtain PageGroup, Document, Line and
# Span classes through get_block_cls() instead of naming them directly, and
# layout/structure lookups go through BLOCK_REGISTRY.
#
# NOTE(review): this copy was rebuilt from a whitespace-collapsed dump. Line
# breaks and blank lines follow the line counts in each hunk header; leading
# indentation of context lines is inferred and should be confirmed against the
# tree (or apply with `git apply --ignore-whitespace`). The position of the
# blank line after super().__init__() in providers/pdf.py is a best guess.
# NOTE(review): in registry.py the BLOCK_REGISTRY key annotation is corrected
# to BlockTypes and the import-time check passes a generator to all() rather
# than a throwaway list, so the blob hash on that file's index line still
# refers to the original content.

diff --git a/marker/v2/builders/document.py b/marker/v2/builders/document.py
index a89b013a..1ea0a057 100644
--- a/marker/v2/builders/document.py
+++ b/marker/v2/builders/document.py
@@ -5,6 +5,7 @@
 from marker.v2.providers.pdf import PdfProvider
 from marker.v2.schema.document import Document
 from marker.v2.schema.groups.page import PageGroup
+from marker.v2.schema.registry import get_block_cls
 
 
 class DocumentBuilder(BaseBuilder):
@@ -15,13 +16,14 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui
         return document
 
     def build_document(self, provider: PdfProvider):
+        PageGroupClass = get_block_cls(PageGroup)
         initial_pages = [
-            PageGroup(
+            PageGroupClass(
                 page_id=i,
                 lowres_image=provider.get_image(i, settings.IMAGE_DPI),
                 highres_image=provider.get_image(i, settings.HIGHRES_IMAGE_DPI),
                 polygon=provider.get_page_bbox(i)
             ) for i in provider.page_range
         ]
-
-        return Document(filepath=provider.filepath, pages=initial_pages)
+        DocumentClass = get_block_cls(Document)
+        return DocumentClass(filepath=provider.filepath, pages=initial_pages)
diff --git a/marker/v2/builders/layout.py b/marker/v2/builders/layout.py
index 4683e890..5cf7842a 100644
--- a/marker/v2/builders/layout.py
+++ b/marker/v2/builders/layout.py
@@ -7,10 +7,10 @@
 from marker.v2.builders import BaseBuilder
 from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
 from marker.v2.schema import BlockTypes
-from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY
 from marker.v2.schema.document import Document
 from marker.v2.schema.groups.page import PageGroup
 from marker.v2.schema.polygon import PolygonBox
+from marker.v2.schema.registry import BLOCK_REGISTRY
 from marker.v2.schema.text.line import Line
 
 
@@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
             layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
             provider_page_size = page.polygon.size
             for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
-                block_cls = LAYOUT_BLOCK_REGISTRY[BlockTypes[bbox.label]]
+                block_cls = BLOCK_REGISTRY[BlockTypes[bbox.label]]
                 layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
                 layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
                 page.add_structure(layout_block)
diff --git a/marker/v2/builders/ocr.py b/marker/v2/builders/ocr.py
index bcf757a4..1e97afb9 100644
--- a/marker/v2/builders/ocr.py
+++ b/marker/v2/builders/ocr.py
@@ -5,10 +5,9 @@
 from marker.settings import settings
 from marker.v2.builders import BaseBuilder
 from marker.v2.providers.pdf import PdfProvider
-from marker.v2.schema import BlockTypes
-from marker.v2.schema.blocks import Block
 from marker.v2.schema.document import Document
 from marker.v2.schema.polygon import PolygonBox
+from marker.v2.schema.registry import get_block_cls
 from marker.v2.schema.text.line import Line
 from marker.v2.schema.text.span import Span
 
@@ -63,6 +62,9 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
 
         page_lines = {}
         page_spans = {}
+        SpanClass = get_block_cls(Span)
+        LineClass = get_block_cls(Line)
+
         for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results):
             page_spans.setdefault(page_id, {})
             page_lines.setdefault(page_id, [])
@@ -74,13 +76,13 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
                 image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox)
                 polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size)
 
-                page_lines[page_id].append(Line(
+                page_lines[page_id].append(LineClass(
                     polygon=polygon,
                     page_id=page_id,
                 ))
 
                 line_spans.setdefault(ocr_line_idx, [])
-                line_spans[ocr_line_idx].append(Span(
+                line_spans[ocr_line_idx].append(SpanClass(
                     text=ocr_line.text,
                     formats=['plain'],
                     page_id=page_id,
diff --git a/marker/v2/builders/structure.py b/marker/v2/builders/structure.py
index 2b9ac994..25137295 100644
--- a/marker/v2/builders/structure.py
+++ b/marker/v2/builders/structure.py
@@ -1,12 +1,9 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
 from marker.v2.builders import BaseBuilder
 from marker.v2.schema import BlockTypes
 from marker.v2.schema.document import Document
-from marker.v2.schema.groups import GROUP_BLOCK_REGISTRY, ListGroup
+from marker.v2.schema.groups import ListGroup
 from marker.v2.schema.groups.page import PageGroup
+from marker.v2.schema.registry import BLOCK_REGISTRY
 
 
 class StructureBuilder(BaseBuilder):
@@ -52,7 +49,7 @@ def group_caption_blocks(self, page: PageGroup):
 
                 if len(block_structure) > 1:
                     # Create a merged block
-                    new_block_cls = GROUP_BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
+                    new_block_cls = BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
                     new_polygon = block.polygon.merge(selected_polygons)
                     group_block = page.add_block(new_block_cls, new_polygon)
                     group_block.structure = block_structure
diff --git a/marker/v2/providers/pdf.py b/marker/v2/providers/pdf.py
index d4492ef1..55aba74d 100644
--- a/marker/v2/providers/pdf.py
+++ b/marker/v2/providers/pdf.py
@@ -4,11 +4,11 @@
 import pypdfium2 as pdfium
 from pdftext.extraction import dictionary_output
 from PIL import Image
-from pydantic import BaseModel
 
 from marker.ocr.heuristics import detect_bad_ocr
 from marker.v2.providers import BaseProvider
 from marker.v2.schema.polygon import PolygonBox
+from marker.v2.schema.registry import get_block_cls
 from marker.v2.schema.text.line import Line
 from marker.v2.schema.text.span import Span
 
@@ -23,7 +23,7 @@ class PdfProvider(BaseProvider):
     flatten_pdf: bool = True
     force_ocr: bool = False
 
-    def __init__(self, filepath: str, config = None):
+    def __init__(self, filepath: str, config=None):
         super().__init__(filepath, config)
 
         self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath)
@@ -105,6 +105,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
             workers=self.pdftext_workers,
             flatten_pdf=self.flatten_pdf
         )
+        SpanClass = get_block_cls(Span)
+        LineClass = get_block_cls(Line)
         for page in page_char_blocks:
             page_id = page["page"]
             lines: List[Line] = []
@@ -120,7 +122,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
                         font_weight = span["font"]["weight"] or 0
                         font_size = span["font"]["size"] or 0
                         spans.append(
-                            Span(
+                            SpanClass(
                                 polygon=PolygonBox.from_bbox(span["bbox"]),
                                 text=span["text"],
                                 font=font_name,
@@ -133,7 +135,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
                                 text_extraction_method="pdftext"
                             )
                         )
-                    lines.append(Line(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
+                    lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
                     line_spans[len(lines) - 1] = spans
             if self.check_line_spans(line_spans):
                 page_lines[page_id] = lines
diff --git a/marker/v2/schema/blocks/__init__.py b/marker/v2/schema/blocks/__init__.py
index d10f8759..09db403c 100644
--- a/marker/v2/schema/blocks/__init__.py
+++ b/marker/v2/schema/blocks/__init__.py
@@ -17,11 +17,3 @@
 from marker.v2.schema.blocks.table import Table
 from marker.v2.schema.blocks.text import Text
 from marker.v2.schema.blocks.toc import TableOfContents
-
-LAYOUT_BLOCK_REGISTRY = {
-    v.model_fields['block_type'].default: v for k, v in locals().items()
-    if isinstance(v, type)
-    and issubclass(v, Block)
-    and v != Block  # Exclude the base Block class
-    and not v.model_fields['block_type'].default.name.endswith("Group")
-}
diff --git a/marker/v2/schema/groups/__init__.py b/marker/v2/schema/groups/__init__.py
index a676ede2..43bf9829 100644
--- a/marker/v2/schema/groups/__init__.py
+++ b/marker/v2/schema/groups/__init__.py
@@ -4,11 +4,3 @@
 from marker.v2.schema.groups.list import ListGroup
 from marker.v2.schema.groups.picture import PictureGroup
 from marker.v2.schema.groups.page import PageGroup
-
-GROUP_BLOCK_REGISTRY = {
-    v.model_fields['block_type'].default: v for k, v in locals().items()
-    if isinstance(v, type)
-    and issubclass(v, Block)
-    and v != Block  # Exclude the base Block class
-    and (v.model_fields['block_type'].default.name.endswith("Group") or v.model_fields['block_type'].default.name == "Page")
-}
diff --git a/marker/v2/schema/registry.py b/marker/v2/schema/registry.py
new file mode 100644
index 00000000..e7337465
--- /dev/null
+++ b/marker/v2/schema/registry.py
@@ -0,0 +1,49 @@
+from typing import Dict, Type, TypeVar
+
+from marker.v2.schema import BlockTypes
+from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \
+    Footnote, Form, Handwriting, InlineMath, \
+    ListItem, PageFooter, PageHeader, Picture, \
+    SectionHeader, Table, TableOfContents, \
+    Text
+from marker.v2.schema.document import Document
+from marker.v2.schema.groups import FigureGroup, ListGroup, PageGroup, \
+    PictureGroup, TableGroup
+from marker.v2.schema.text import Line, Span
+
+BLOCK_REGISTRY: Dict[BlockTypes, Type[Block]] = {
+    BlockTypes.Line: Line,
+    BlockTypes.Span: Span,
+    BlockTypes.FigureGroup: FigureGroup,
+    BlockTypes.TableGroup: TableGroup,
+    BlockTypes.ListGroup: ListGroup,
+    BlockTypes.PictureGroup: PictureGroup,
+    BlockTypes.Page: PageGroup,
+    BlockTypes.Caption: Caption,
+    BlockTypes.Code: Code,
+    BlockTypes.Figure: Figure,
+    BlockTypes.Footnote: Footnote,
+    BlockTypes.Form: Form,
+    BlockTypes.Equation: Equation,
+    BlockTypes.Handwriting: Handwriting,
+    BlockTypes.TextInlineMath: InlineMath,
+    BlockTypes.ListItem: ListItem,
+    BlockTypes.PageFooter: PageFooter,
+    BlockTypes.PageHeader: PageHeader,
+    BlockTypes.Picture: Picture,
+    BlockTypes.SectionHeader: SectionHeader,
+    BlockTypes.Table: Table,
+    BlockTypes.Text: Text,
+    BlockTypes.TableOfContents: TableOfContents,
+    BlockTypes.Document: Document,
+}
+
+T = TypeVar('T')
+
+
+def get_block_cls(block_cls: T) -> T:
+    return BLOCK_REGISTRY.get(block_cls.model_fields['block_type'].default, block_cls)
+
+
+assert len(BLOCK_REGISTRY) == len(BlockTypes)
+assert all(v.model_fields['block_type'].default == k for k, v in BLOCK_REGISTRY.items())
diff --git a/marker/v2/schema/text/__init__.py b/marker/v2/schema/text/__init__.py
index 64015b0d..8901b65f 100644
--- a/marker/v2/schema/text/__init__.py
+++ b/marker/v2/schema/text/__init__.py
@@ -1,8 +1,2 @@
-from marker.v2.schema import BlockTypes
 from marker.v2.schema.text.line import Line
 from marker.v2.schema.text.span import Span
-
-TEXT_BLOCK_REGISTRY = {
-    BlockTypes.Line: Line,
-    BlockTypes.Span: Span,
-}