Skip to content

Commit

Permalink
static registry + ability to override nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Nov 18, 2024
1 parent 9a793c8 commit b674363
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 41 deletions.
8 changes: 5 additions & 3 deletions marker/v2/builders/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import get_block_cls


class DocumentBuilder(BaseBuilder):
Expand All @@ -15,13 +16,14 @@ def __call__(self, provider: PdfProvider, layout_builder: LayoutBuilder, ocr_bui
return document

def build_document(self, provider: PdfProvider):
PageGroupClass = get_block_cls(PageGroup)
initial_pages = [
PageGroup(
PageGroupClass(
page_id=i,
lowres_image=provider.get_image(i, settings.IMAGE_DPI),
highres_image=provider.get_image(i, settings.HIGHRES_IMAGE_DPI),
polygon=provider.get_page_bbox(i)
) for i in provider.page_range
]

return Document(filepath=provider.filepath, pages=initial_pages)
DocumentClass = get_block_cls(Document)
return DocumentClass(filepath=provider.filepath, pages=initial_pages)
4 changes: 2 additions & 2 deletions marker/v2/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PageLines, PageSpans, PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import LAYOUT_BLOCK_REGISTRY
from marker.v2.schema.document import Document
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import BLOCK_REGISTRY
from marker.v2.schema.text.line import Line


Expand Down Expand Up @@ -49,7 +49,7 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
provider_page_size = page.polygon.size
for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
block_cls = LAYOUT_BLOCK_REGISTRY[BlockTypes[bbox.label]]
block_cls = BLOCK_REGISTRY[BlockTypes[bbox.label]]
layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
page.add_structure(layout_block)
Expand Down
10 changes: 6 additions & 4 deletions marker/v2/builders/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from marker.settings import settings
from marker.v2.builders import BaseBuilder
from marker.v2.providers.pdf import PdfProvider
from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block
from marker.v2.schema.document import Document
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_cls
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand Down Expand Up @@ -63,6 +62,9 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
page_lines = {}
page_spans = {}

SpanClass = get_block_cls(Span)
LineClass = get_block_cls(Line)

for page_id, recognition_result in zip((page.page_id for page in page_list), recognition_results):
page_spans.setdefault(page_id, {})
page_lines.setdefault(page_id, [])
Expand All @@ -74,13 +76,13 @@ def ocr_extraction(self, document: Document, provider: PdfProvider) -> Tuple[Pag
image_polygon = PolygonBox.from_bbox(recognition_result.image_bbox)
polygon = PolygonBox.from_bbox(ocr_line.bbox).rescale(image_polygon.size, page_size)

page_lines[page_id].append(Line(
page_lines[page_id].append(LineClass(
polygon=polygon,
page_id=page_id,
))

line_spans.setdefault(ocr_line_idx, [])
line_spans[ocr_line_idx].append(Span(
line_spans[ocr_line_idx].append(SpanClass(
text=ocr_line.text,
formats=['plain'],
page_id=page_id,
Expand Down
9 changes: 3 additions & 6 deletions marker/v2/builders/structure.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Optional

from pydantic import BaseModel

from marker.v2.builders import BaseBuilder
from marker.v2.schema import BlockTypes
from marker.v2.schema.document import Document
from marker.v2.schema.groups import GROUP_BLOCK_REGISTRY, ListGroup
from marker.v2.schema.groups import ListGroup
from marker.v2.schema.groups.page import PageGroup
from marker.v2.schema.registry import BLOCK_REGISTRY


class StructureBuilder(BaseBuilder):
Expand Down Expand Up @@ -52,7 +49,7 @@ def group_caption_blocks(self, page: PageGroup):

if len(block_structure) > 1:
# Create a merged block
new_block_cls = GROUP_BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
new_block_cls = BLOCK_REGISTRY[BlockTypes[block.block_type.name + "Group"]]
new_polygon = block.polygon.merge(selected_polygons)
group_block = page.add_block(new_block_cls, new_polygon)
group_block.structure = block_structure
Expand Down
10 changes: 6 additions & 4 deletions marker/v2/providers/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import pypdfium2 as pdfium
from pdftext.extraction import dictionary_output
from PIL import Image
from pydantic import BaseModel

from marker.ocr.heuristics import detect_bad_ocr
from marker.v2.providers import BaseProvider
from marker.v2.schema.polygon import PolygonBox
from marker.v2.schema.registry import get_block_cls
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

Expand All @@ -23,7 +23,7 @@ class PdfProvider(BaseProvider):
flatten_pdf: bool = True
force_ocr: bool = False

def __init__(self, filepath: str, config = None):
def __init__(self, filepath: str, config=None):
super().__init__(filepath, config)

self.doc: pdfium.PdfDocument = pdfium.PdfDocument(self.filepath)
Expand Down Expand Up @@ -105,6 +105,8 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
workers=self.pdftext_workers,
flatten_pdf=self.flatten_pdf
)
SpanClass = get_block_cls(Span)
LineClass = get_block_cls(Line)
for page in page_char_blocks:
page_id = page["page"]
lines: List[Line] = []
Expand All @@ -120,7 +122,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
font_weight = span["font"]["weight"] or 0
font_size = span["font"]["size"] or 0
spans.append(
Span(
SpanClass(
polygon=PolygonBox.from_bbox(span["bbox"]),
text=span["text"],
font=font_name,
Expand All @@ -133,7 +135,7 @@ def pdftext_extraction(self) -> Tuple[PageLines, PageSpans]:
text_extraction_method="pdftext"
)
)
lines.append(Line(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
lines.append(LineClass(polygon=PolygonBox.from_bbox(line["bbox"]), page_id=page_id))
line_spans[len(lines) - 1] = spans
if self.check_line_spans(line_spans):
page_lines[page_id] = lines
Expand Down
8 changes: 0 additions & 8 deletions marker/v2/schema/blocks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,3 @@
from marker.v2.schema.blocks.table import Table
from marker.v2.schema.blocks.text import Text
from marker.v2.schema.blocks.toc import TableOfContents

LAYOUT_BLOCK_REGISTRY = {
v.model_fields['block_type'].default: v for k, v in locals().items()
if isinstance(v, type)
and issubclass(v, Block)
and v != Block # Exclude the base Block class
and not v.model_fields['block_type'].default.name.endswith("Group")
}
8 changes: 0 additions & 8 deletions marker/v2/schema/groups/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,3 @@
from marker.v2.schema.groups.list import ListGroup
from marker.v2.schema.groups.picture import PictureGroup
from marker.v2.schema.groups.page import PageGroup

GROUP_BLOCK_REGISTRY = {
v.model_fields['block_type'].default: v for k, v in locals().items()
if isinstance(v, type)
and issubclass(v, Block)
and v != Block # Exclude the base Block class
and (v.model_fields['block_type'].default.name.endswith("Group") or v.model_fields['block_type'].default.name == "Page")
}
49 changes: 49 additions & 0 deletions marker/v2/schema/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Dict, Type, TypeVar

from marker.v2.schema import BlockTypes
from marker.v2.schema.blocks import Block, Caption, Code, Equation, Figure, \
Footnote, Form, Handwriting, InlineMath, \
ListItem, PageFooter, PageHeader, Picture, \
SectionHeader, Table, TableOfContents, \
Text
from marker.v2.schema.document import Document
from marker.v2.schema.groups import FigureGroup, ListGroup, PageGroup, \
PictureGroup, TableGroup
from marker.v2.schema.text import Line, Span

# Single static lookup table from a BlockTypes member to the class used to
# represent that type. get_block_cls() consults this table, so the class stored
# under a key is the one handed back for any class sharing that block type.
BLOCK_REGISTRY: Dict[BlockTypes, Type[Block]] = {
# Text primitives (imported from marker.v2.schema.text).
BlockTypes.Line: Line,
BlockTypes.Span: Span,
# Group containers (imported from marker.v2.schema.groups); note that the
# Page block type is served by PageGroup.
BlockTypes.FigureGroup: FigureGroup,
BlockTypes.TableGroup: TableGroup,
BlockTypes.ListGroup: ListGroup,
BlockTypes.PictureGroup: PictureGroup,
BlockTypes.Page: PageGroup,
# Individual blocks (imported from marker.v2.schema.blocks).
BlockTypes.Caption: Caption,
BlockTypes.Code: Code,
BlockTypes.Figure: Figure,
BlockTypes.Footnote: Footnote,
BlockTypes.Form: Form,
BlockTypes.Equation: Equation,
BlockTypes.Handwriting: Handwriting,
BlockTypes.TextInlineMath: InlineMath,
BlockTypes.ListItem: ListItem,
BlockTypes.PageFooter: PageFooter,
BlockTypes.PageHeader: PageHeader,
BlockTypes.Picture: Picture,
BlockTypes.SectionHeader: SectionHeader,
BlockTypes.Table: Table,
BlockTypes.Text: Text,
BlockTypes.TableOfContents: TableOfContents,
# Top-level document container (imported from marker.v2.schema.document).
BlockTypes.Document: Document,
}

T = TypeVar('T')


def get_block_cls(block_cls: Type[T]) -> Type[T]:
    """Return the class registered for the block type of ``block_cls``.

    The lookup key is the default value of the ``block_type`` field declared
    on ``block_cls``. When ``BLOCK_REGISTRY`` holds an entry for that key the
    registered class is returned; otherwise ``block_cls`` itself is returned
    unchanged.

    Args:
        block_cls: A class (not an instance) whose ``model_fields`` mapping
            contains a ``block_type`` entry with a ``default`` attribute.

    Returns:
        The class stored in ``BLOCK_REGISTRY`` for that block type, or
        ``block_cls`` when the registry has no entry for it.
    """
    # Resolve the key first so the registry lookup and its fallback read clearly.
    block_type = block_cls.model_fields['block_type'].default
    return BLOCK_REGISTRY.get(block_type, block_cls)


# Import-time sanity checks on the static registry. They are plain asserts, so
# they are skipped when Python runs with -O.
# 1) The registry has exactly as many entries as BlockTypes has members.
assert len(BLOCK_REGISTRY) == len(BlockTypes), (
    "BLOCK_REGISTRY is out of sync with BlockTypes; unregistered block types: "
    f"{[block_type.name for block_type in BlockTypes if block_type not in BLOCK_REGISTRY]}"
)
# 2) Every class is stored under the block type it declares as its default.
assert all(
    cls.model_fields['block_type'].default == block_type
    for block_type, cls in BLOCK_REGISTRY.items()
), "BLOCK_REGISTRY has an entry whose key differs from its class's default block_type"
6 changes: 0 additions & 6 deletions marker/v2/schema/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
from marker.v2.schema import BlockTypes
from marker.v2.schema.text.line import Line
from marker.v2.schema.text.span import Span

TEXT_BLOCK_REGISTRY = {
BlockTypes.Line: Line,
BlockTypes.Span: Span,
}

0 comments on commit b674363

Please sign in to comment.