diff --git a/tests/test_overriding.py b/tests/test_overriding.py index b70bb1c7..01c4e588 100644 --- a/tests/test_overriding.py +++ b/tests/test_overriding.py @@ -1,14 +1,24 @@ +import multiprocessing as mp + import pytest +from marker.v2.providers.pdf import PdfProvider from marker.v2.schema import BlockTypes -from marker.v2.schema.document import Document from marker.v2.schema.blocks import SectionHeader +from marker.v2.schema.document import Document +from marker.v2.schema.registry import register_block_class +from marker.v2.schema.text import Line +from tests.utils import setup_pdf_provider class NewSectionHeader(SectionHeader): pass +class NewLine(Line): + pass + + @pytest.mark.config({ "page_range": [0], "override_map": {BlockTypes.SectionHeader: NewSectionHeader} @@ -16,3 +26,24 @@ class NewSectionHeader(SectionHeader): def test_overriding(pdf_document: Document): assert pdf_document.pages[0]\ .get_block(pdf_document.pages[0].structure[0]).__class__ == NewSectionHeader + + +def get_lines(pdf: str, config=None): + provider: PdfProvider = setup_pdf_provider(pdf, config) + return provider.get_page_lines(0) + + +def test_overriding_mp(): + config = { + "page_range": [0], + "override_map": {BlockTypes.Line: NewLine} + } + + for block_type, block_cls in config["override_map"].items(): + register_block_class(block_type, block_cls) + + pdf_list = ["adversarial.pdf", "adversarial_rot.pdf"] + + with mp.Pool(processes=2) as pool: + results = pool.starmap(get_lines, [(pdf, config) for pdf in pdf_list]) + assert all([r[0].__class__ == NewLine for r in results]) diff --git a/tests/utils.py b/tests/utils.py index b2e01db5..1f7a2911 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,10 +9,10 @@ from marker.v2.schema.document import Document -def setup_pdf_document( +def setup_pdf_provider( filename='adversarial.pdf', config=None, -) -> Document: +) -> PdfProvider: dataset = datasets.load_dataset("datalab-to/pdfs", split="train") idx = dataset['filename'].index(filename) @@ -20,11 +20,19 @@ def setup_pdf_document( temp_pdf.write(dataset['pdf'][idx]) temp_pdf.flush() + provider = PdfProvider(temp_pdf.name, config) + return provider + + +def setup_pdf_document( + filename='adversarial.pdf', + config=None, +) -> Document: layout_model = setup_layout_model() recognition_model = setup_recognition_model() detection_model = setup_detection_model() - provider = PdfProvider(temp_pdf.name, config) + provider = setup_pdf_provider(filename, config) layout_builder = LayoutBuilder(layout_model, config) ocr_builder = OcrBuilder(detection_model, recognition_model, config) builder = DocumentBuilder(config)