Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #435

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open

Dev #435

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4b0dc70
Add high quality inlinemath processor
iammosespaulr Dec 11, 2024
41c2864
save state
iammosespaulr Dec 12, 2024
6df9ca1
Merge remote-tracking branch 'origin/master' into highquality-processors
iammosespaulr Dec 12, 2024
d93848b
env var for google api key
iammosespaulr Dec 13, 2024
efc612c
cleanup and speedup
iammosespaulr Dec 13, 2024
3a4c966
add top_k
iammosespaulr Dec 13, 2024
922bcab
Add New OCR Heuristics Model
tarun-menta Dec 15, 2024
8e84595
Update error detection logic
tarun-menta Dec 16, 2024
0c6c117
add high quality builder
iammosespaulr Dec 16, 2024
08052b6
subclass layout builder and add a high quality text processor
iammosespaulr Dec 17, 2024
9565572
Fix table error
VikParuchuri Dec 19, 2024
7c62bc1
Merge pull request #427 from tarun-menta/ocr-model-heuristics
VikParuchuri Dec 19, 2024
232194b
integrate changes and increment surya version
iammosespaulr Dec 19, 2024
f367db0
Add tests
VikParuchuri Dec 19, 2024
4886d43
Merge pull request #434 from VikParuchuri/vik_dev
VikParuchuri Dec 19, 2024
1a28996
missing deps
iammosespaulr Dec 19, 2024
0bdc447
add retries, concurrency limits and timeouts
iammosespaulr Dec 19, 2024
a5de368
Fix test
VikParuchuri Dec 20, 2024
e48364c
parse out and recreate spans in the high quality text processor and h…
iammosespaulr Dec 20, 2024
32f8095
render block mode the same way
iammosespaulr Dec 20, 2024
7542422
configurable math delimiters
iammosespaulr Dec 20, 2024
796966c
Merge remote-tracking branch 'origin/dev' into highquality-processors
iammosespaulr Dec 20, 2024
4d4c469
fix highQualityLayoutBuilder
iammosespaulr Dec 20, 2024
d2c32af
fix tests
iammosespaulr Dec 20, 2024
26f68be
Merge pull request #429 from VikParuchuri/highquality-processors
VikParuchuri Dec 20, 2024
5ea06c0
Merge plus form processor
VikParuchuri Dec 20, 2024
46dde3f
Refactor classes
VikParuchuri Dec 20, 2024
0d1c9ff
Add tests for llm processors
VikParuchuri Dec 20, 2024
4413bf3
Merge pull request #438 from VikParuchuri/vik_dev
VikParuchuri Dec 20, 2024
39fdce5
Add documentation for LLM mode
VikParuchuri Dec 20, 2024
375221c
Fix tests
VikParuchuri Dec 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Marker converts PDFs to markdown, JSON, and HTML quickly and accurately.
- Extracts and saves images along with the markdown
- Converts equations to latex
- Easily extensible with your own formatting and logic
- Optionally boost accuracy with an LLM
- Works on GPU, CPU, or MPS

## How it works
Expand Down Expand Up @@ -99,10 +100,11 @@ marker_single /path/to/file.pdf

Options:
- `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
- `--debug`: Enable debug mode for additional logging and diagnostic information.
- `--output_format [markdown|json|html]`: Specify the format for the output results.
- `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
- `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
- `--force_ocr`: Force OCR processing on the entire document, even for pages that might contain extractable text.
- `--debug`: Enable debug mode for additional logging and diagnostic information.
- `--processors TEXT`: Override the default processors by providing their full module paths, separated by commas. Example: `--processors "module1.processor1,module2.processor2"`
- `--config_json PATH`: Path to a JSON configuration file containing additional settings.
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "eng,fra,deu"` for English, French, and German.
Expand All @@ -127,7 +129,6 @@ NUM_DEVICES=4 NUM_WORKERS=15 marker_chunk_convert ../pdf_in ../md_out

- `NUM_DEVICES` is the number of GPUs to use. Should be `2` or greater.
- `NUM_WORKERS` is the number of parallel processes to run on each GPU.
-

## Use from python

Expand All @@ -149,7 +150,7 @@ text, _, images = text_from_rendered(rendered)

### Custom configuration

You can also pass configuration using the `ConfigParser`:
You can pass configuration using the `ConfigParser`:

```python
from marker.converters.pdf import PdfConverter
Expand All @@ -171,6 +172,26 @@ converter = PdfConverter(
rendered = converter("FILEPATH")
```

### Extract blocks

Each document consists of one or more pages. Pages contain blocks, which can themselves contain other blocks. It's possible to programatically manipulate these blocks.

Here's an example of extracting all forms from a document:

```python
from marker.converters.pdf import PdfConverter
from marker.models import create_model_dict
from marker.schema import BlockTypes

converter = PdfConverter(
artifact_dict=create_model_dict(),
)
document = converter.build_document("FILEPATH")
forms = document.contained_blocks((BlockTypes.Form,))
```

Look at the processors for more examples of extracting and manipulating blocks.

# Output Formats

## Markdown
Expand Down Expand Up @@ -312,6 +333,7 @@ Note that this is not a very robust API, and is only intended for small-scale us

There are some settings that you may find useful if things aren't working the way you expect:

- If you have issues with accuracy, try setting `--use_llm` to use an LLM to improve quality. You must set `GOOGLE_API_KEY` to a Gemini API key for this to work.
- Make sure to set `force_ocr` if you see garbled text - this will re-OCR the document.
- `TORCH_DEVICE` - set this to force marker to use a given torch device for inference.
- If you're getting out of memory errors, decrease worker count. You can also try splitting up long PDFs into multiple files.
Expand Down
2 changes: 2 additions & 0 deletions convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
os.environ["IN_STREAMLIT"] = "true" # Avoid multiprocessing inside surya

Expand Down
5 changes: 3 additions & 2 deletions convert_single.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
os.environ["GRPC_VERBOSITY"] = "ERROR"
os.environ["GLOG_minloglevel"] = "2"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

import time

import click

from marker.config.parser import ConfigParser
Expand Down
48 changes: 42 additions & 6 deletions marker/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from surya.schema import LayoutResult
from surya.model.layout.encoderdecoder import SuryaLayoutModel

from surya.ocr_error import batch_ocr_error_detection
from surya.schema import OCRErrorDetectionResult
from surya.model.ocr_error.model import DistilBertForSequenceClassification

from marker.settings import settings
from marker.builders import BaseBuilder
from marker.providers import ProviderOutput, ProviderPageLines
Expand Down Expand Up @@ -37,15 +41,21 @@ class LayoutBuilder(BaseBuilder):
document_ocr_threshold (float):
The minimum ratio of pages that must pass the layout coverage check
to avoid OCR. Default is 0.8.

error_model_segment_length (int):
The maximum number of characters to send to the OCR error model.
Default is 1024.
"""
batch_size = None
layout_coverage_min_lines = 1
layout_coverage_threshold = .1
document_ocr_threshold = .8
error_model_segment_length = 512
excluded_for_coverage = (BlockTypes.Figure, BlockTypes.Picture, BlockTypes.Table, BlockTypes.FigureGroup, BlockTypes.TableGroup, BlockTypes.PictureGroup)

def __init__(self, layout_model: SuryaLayoutModel, config=None):
def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
self.layout_model = layout_model
self.ocr_error_model = ocr_error_model

super().__init__(config)

Expand All @@ -71,15 +81,41 @@ def surya_layout(self, pages: List[PageGroup]) -> List[LayoutResult]:
)
return layout_results

def surya_ocr_error_detection(self, pages:List[PageGroup], provider_page_lines: ProviderPageLines) -> OCRErrorDetectionResult:
page_texts = []
for document_page in pages:
page_text = ''
provider_lines = provider_page_lines.get(document_page.page_id, [])
for line in provider_lines:
page_text += ' '.join([s.text for s in line.spans])

# Sample text from the middle
if len(page_text) > 0:
page_text_middle = len(page_text) // 2
page_text_start = max(0, page_text_middle - self.error_model_segment_length // 2)
page_text_end = page_text_start + self.error_model_segment_length
page_text = page_text[page_text_start:page_text_end]

page_texts.append(page_text)

ocr_error_detection_results = batch_ocr_error_detection(
page_texts,
self.ocr_error_model,
self.ocr_error_model.tokenizer,
batch_size=int(self.get_batch_size()) #TODO Better Multiplier
)
return ocr_error_detection_results

def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[LayoutResult]):
for page, layout_result in zip(pages, layout_results):
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
provider_page_size = page.polygon.size
page.layout_sliced = layout_result.sliced # This indicates if the page was sliced by the layout model
page.layout_sliced = layout_result.sliced # This indicates if the page was sliced by the layout model
for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
block_cls = get_block_class(BlockTypes[bbox.label])
layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
layout_block.top_k = {BlockTypes[label]: prob for (label, prob) in bbox.top_k.items()}
page.add_structure(layout_block)

# Ensure page has non-empty structure
Expand All @@ -91,16 +127,17 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
page.children = []

def merge_blocks(self, document_pages: List[PageGroup], provider_page_lines: ProviderPageLines):
ocr_error_detection_labels = self.surya_ocr_error_detection(document_pages, provider_page_lines).labels

good_pages = []
for document_page in document_pages:
for (document_page, ocr_error_detection_label) in zip(document_pages, ocr_error_detection_labels):
provider_lines = provider_page_lines.get(document_page.page_id, [])
good_pages.append(self.check_layout_coverage(document_page, provider_lines))
good_pages.append(self.check_layout_coverage(document_page, provider_lines) and (ocr_error_detection_label != "bad"))

ocr_document = sum(good_pages) / len(good_pages) < self.document_ocr_threshold
for idx, document_page in enumerate(document_pages):
provider_lines = provider_page_lines.get(document_page.page_id, [])
needs_ocr = not good_pages[idx]

if needs_ocr and ocr_document:
document_page.text_extraction_method = "surya"
continue
Expand Down Expand Up @@ -141,4 +178,3 @@ def check_layout_coverage(
if not text_okay and (total_blocks == 1 and large_text_blocks == 1):
text_okay = True
return text_okay

138 changes: 138 additions & 0 deletions marker/builders/llm_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import google.generativeai as genai
import PIL
from google.ai.generativelanguage_v1beta.types import content
from google.api_core.exceptions import ResourceExhausted
from surya.model.layout.encoderdecoder import SuryaLayoutModel
from surya.model.ocr_error.model import DistilBertForSequenceClassification
from tqdm import tqdm

from marker.builders.layout import LayoutBuilder
from marker.processors.llm import GoogleModel
from marker.providers.pdf import PdfProvider
from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup
from marker.schema.registry import get_block_class
from marker.settings import settings


class LLMLayoutBuilder(LayoutBuilder):
"""
A builder for relabelling blocks to improve the quality of the layout.

Attributes:
google_api_key (str):
The Google API key to use for the Gemini model.
Default is None.
confidence_threshold (float):
The confidence threshold to use for relabeling.
Default is 0.8.
model_name (str):
The name of the Gemini model to use.
Default is "gemini-1.5-flash".
max_retries (int):
The maximum number of retries to use for the Gemini model.
Default is 3.
max_concurrency (int):
The maximum number of concurrent requests to make to the Gemini model.
Default is 3.
timeout (int):
The timeout for requests to the Gemini model.
Default is 60 seconds.
gemini_relabelling_prompt (str):
The prompt to use for relabelling blocks.
Default is a string containing the Gemini relabelling prompt.
"""

google_api_key: Optional[str] = settings.GOOGLE_API_KEY
confidence_threshold: float = 0.75
model_name: str = "gemini-1.5-flash"
max_retries: int = 3
max_concurrency: int = 3
timeout: int = 60

gemini_relabelling_prompt = """You are a layout expert specializing in document analysis.
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
You will be provided with an image of a layout block and the top k predictions from the current model, along with their confidence scores.
Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
Do not invent any new labels.
Carefully examine the image and consider the provided predictions.
Choose the label you believe is the most accurate representation of the layout block.

Here are the top k predictions from the model followed by the image:

"""

def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
super().__init__(layout_model, ocr_error_model, config)

self.model = GoogleModel(self.google_api_key, self.model_name)

def __call__(self, document: Document, provider: PdfProvider):
super().__call__(document, provider)
try:
self.relabel_blocks(document)
except Exception as e:
print(f"Error relabelling blocks: {e}")

def relabel_blocks(self, document: Document):
pbar = tqdm(desc="LLM layout relabelling")
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
futures = []
for page in document.pages:
for block_id in page.structure:
block = page.get_block(block_id)
if block.top_k:
confidence = block.top_k.get(block.block_type)
if confidence < self.confidence_threshold:
futures.append(executor.submit(self.process_block_relabelling, page, block))

for future in as_completed(futures):
future.result() # Raise exceptions if any occurred
pbar.update(1)

pbar.close()

def process_block_relabelling(self, page: PageGroup, block: Block):
topk = {str(k): round(v, 3) for k, v in block.top_k.items()}

prompt = self.gemini_relabelling_prompt + '```json' + json.dumps(topk) + '```\n'
image = self.extract_image(page, block)
response_schema = content.Schema(
type=content.Type.OBJECT,
enum=[],
required=["label"],
properties={
"label": content.Schema(
type=content.Type.STRING,
),
},
)

response = self.model.generate_response(prompt, image, response_schema)
generated_label = None
if response and "label" in response:
generated_label = response["label"]

if generated_label and generated_label != str(block.block_type) and generated_label in BlockTypes:
generated_block_class = get_block_class(BlockTypes[generated_label])
generated_block = generated_block_class(
polygon=block.polygon,
page_id=block.page_id,
structure=block.structure,
)
page.replace_block(block, generated_block)

def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
page_img = page.lowres_image
image_box = image_block.polygon\
.rescale(page.polygon.size, page_img.size)\
.expand(expand, expand)
cropped = page_img.crop(image_box.bbox)
return cropped
38 changes: 18 additions & 20 deletions marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,41 +34,39 @@ def common_options(fn):
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
return fn

def generate_config_dict(self) -> Dict[str, any]:
config = {}
output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
for k, v in self.cli_options.items():
if not v:
continue

match k:
case "debug":
if v:
config["debug_pdf_images"] = True
config["debug_layout_images"] = True
config["debug_json"] = True
config["debug_data_folder"] = output_dir
config["debug_pdf_images"] = True
config["debug_layout_images"] = True
config["debug_json"] = True
config["debug_data_folder"] = output_dir
case "page_range":
if v:
config["page_range"] = parse_range_str(v)
config["page_range"] = parse_range_str(v)
case "force_ocr":
if v:
config["force_ocr"] = True
config["force_ocr"] = True
case "languages":
if v:
config["languages"] = v.split(",")
config["languages"] = v.split(",")
case "config_json":
if v:
with open(v, "r") as f:
config.update(json.load(f))
with open(v, "r") as f:
config.update(json.load(f))
case "disable_multiprocessing":
if v:
config["pdftext_workers"] = 1
config["pdftext_workers"] = 1
case "paginate_output":
if v:
config["paginate_output"] = True
config["paginate_output"] = True
case "disable_image_extraction":
if v:
config["extract_images"] = False
config["extract_images"] = False
case "use_llm":
config["use_llm"] = True
return config

def get_renderer(self):
Expand Down
Loading
Loading