Skip to content

Commit

Permalink
Merge plus form processor
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Dec 20, 2024
1 parent 26f68be commit 5ea06c0
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 24 deletions.
40 changes: 17 additions & 23 deletions marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,39 @@ def common_options(fn):
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with Gemini.")(fn)
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with LLMs.")(fn)
return fn

def generate_config_dict(self) -> Dict[str, any]:
config = {}
output_dir = self.cli_options.get("output_dir", settings.OUTPUT_DIR)
for k, v in self.cli_options.items():
if not v:
continue

match k:
case "debug":
if v:
config["debug_pdf_images"] = True
config["debug_layout_images"] = True
config["debug_json"] = True
config["debug_data_folder"] = output_dir
config["debug_pdf_images"] = True
config["debug_layout_images"] = True
config["debug_json"] = True
config["debug_data_folder"] = output_dir
case "page_range":
if v:
config["page_range"] = parse_range_str(v)
config["page_range"] = parse_range_str(v)
case "force_ocr":
if v:
config["force_ocr"] = True
config["force_ocr"] = True
case "languages":
if v:
config["languages"] = v.split(",")
config["languages"] = v.split(",")
case "config_json":
if v:
with open(v, "r") as f:
config.update(json.load(f))
with open(v, "r") as f:
config.update(json.load(f))
case "disable_multiprocessing":
if v:
config["pdftext_workers"] = 1
config["pdftext_workers"] = 1
case "paginate_output":
if v:
config["paginate_output"] = True
config["paginate_output"] = True
case "disable_image_extraction":
if v:
config["extract_images"] = False
config["extract_images"] = False
case "high_quality":
if v:
config["high_quality"] = True
config["high_quality"] = True
return config

def get_renderer(self):
Expand Down
6 changes: 5 additions & 1 deletion marker/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import inspect
from collections import defaultdict
Expand All @@ -17,6 +17,8 @@
from marker.processors.document_toc import DocumentTOCProcessor
from marker.processors.equation import EquationProcessor
from marker.processors.footnote import FootnoteProcessor
from marker.processors.llm.highqualityformprocessor import HighQualityFormProcessor
from marker.processors.llm.highqualitytableprocessor import HighQualityTableProcessor
from marker.processors.high_quality_text import HighQualityTextProcessor
from marker.processors.ignoretext import IgnoreTextProcessor
from marker.processors.line_numbers import LineNumbersProcessor
Expand Down Expand Up @@ -68,6 +70,8 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
PageHeaderProcessor,
SectionHeaderProcessor,
TableProcessor,
HighQualityTableProcessor,
HighQualityFormProcessor,
TextProcessor,
HighQualityTextProcessor,
DebugProcessor,
Expand Down
55 changes: 55 additions & 0 deletions marker/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
import time

import PIL
import google.generativeai as genai
from google.ai.generativelanguage_v1beta.types import content
from google.api_core.exceptions import ResourceExhausted


class GoogleModel:
def __init__(self, api_key: str, model_name: str):
if api_key is None:
raise ValueError("Google API key is not set")

self.api_key = api_key
self.model_name = model_name
self.model = self.configure_google_model()

def configure_google_model(self):
genai.configure(api_key=self.api_key)
return genai.GenerativeModel(self.model_name)

def generate_response(
self,
prompt: str,
image: PIL.Image.Image,
response_schema: content.Schema,
max_retries: int = 3,
timeout: int = 60
):
tries = 0
while tries < max_retries:
try:
responses = self.model.generate_content(
[prompt, image],
stream=False,
generation_config={
"temperature": 0,
"response_schema": response_schema,
"response_mime_type": "application/json",
},
request_options={'timeout': timeout}
)
output = responses.candidates[0].content.parts[0].text
return json.loads(output)
except ResourceExhausted as e:
tries += 1
wait_time = tries * 3
print(f"ResourceExhausted: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{max_retries})")
time.sleep(wait_time)
except Exception as e:
print(e)
break

return {}
Empty file.
151 changes: 151 additions & 0 deletions marker/processors/llm/highqualityformprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import markdown2

from marker.llm import GoogleModel
from marker.processors import BaseProcessor
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from google.ai.generativelanguage_v1beta.types import content
from tqdm import tqdm
from tabled.formats import markdown_format

from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup
from marker.settings import settings


class HighQualityFormProcessor(BaseProcessor):
"""
A processor for converting form blocks in a document to markdown.
Attributes:
google_api_key (str):
The Google API key to use for the Gemini model.
Default is None.
model_name (str):
The name of the Gemini model to use.
Default is "gemini-1.5-flash".
max_retries (int):
The maximum number of retries to use for the Gemini model.
Default is 3.
max_concurrency (int):
The maximum number of concurrent requests to make to the Gemini model.
Default is 3.
timeout (int):
The timeout for requests to the Gemini model.
gemini_rewriting_prompt (str):
The prompt to use for rewriting text.
Default is a string containing the Gemini rewriting prompt.
"""

block_types = (BlockTypes.Form,)
google_api_key: Optional[str] = settings.GOOGLE_API_KEY
model_name: str = "gemini-1.5-flash"
high_quality: bool = False
max_retries: int = 3
max_concurrency: int = 3
timeout: int = 60

gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a markdown representation of the form in the image.
Your task is to correct any errors in the markdown representation, and format it properly.
Values and labels should appear in markdown tables, with the labels on the left side, and values on the right. The headers should be "Labels" and "Values". Other text in the form can appear between the tables.
**Instructions:**
1. Carefully examine the provided form block image.
2. Analyze the markdown representation of the form.
3. If the markdown representation is largely correct, then write "No corrections needed."
4. If the markdown representation contains errors, generate the corrected markdown representation.
5. Output only either the corrected markdown representation or "No corrections needed."
**Example:**
Input:
```markdown
| Label 1 | Label 2 | Label 3 |
|----------|----------|----------|
| Value 1 | Value 2 | Value 3 |
```
Output:
```markdown
| Labels | Values |
|--------|--------|
| Label 1 | Value 1 |
| Label 2 | Value 2 |
| Label 3 | Value 3 |
```
**Input:**
"""

def __init__(self, config=None):
super().__init__(config)

self.model = None
if not self.high_quality:
return

self.model = GoogleModel(self.google_api_key, self.model_name)

def __call__(self, document: Document):
if not self.high_quality or self.model is None:
return

self.rewrite_blocks(document)

def rewrite_blocks(self, document: Document):
pbar = tqdm(desc="High quality form processor")
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
for future in as_completed([
executor.submit(self.process_rewriting, page, block)
for page in document.pages
for block in page.contained_blocks(document, self.block_types)
]):
future.result() # Raise exceptions if any occurred
pbar.update(1)

pbar.close()

def process_rewriting(self, page: PageGroup, block: Block):
cells = block.cells
if cells is None:
# Happens if table/form processors didn't run
return

prompt = self.gemini_rewriting_prompt + '```markdown\n`' + markdown_format(cells) + '`\n```\n'
image = self.extract_image(page, block)
response_schema = content.Schema(
type=content.Type.OBJECT,
enum=[],
required=["corrected_markdown"],
properties={
"corrected_markdown": content.Schema(
type=content.Type.STRING
)
},
)

response = self.model.generate_response(prompt, image, response_schema)

if not response or "corrected_markdown" not in response:
return

corrected_markdown = response["corrected_markdown"]

# The original table is okay
if "no corrections" in corrected_markdown.lower():
return

orig_cell_text = "".join([cell.text for cell in cells])

# Potentially a partial response
if len(corrected_markdown) < len(orig_cell_text) * .5:
return

# Convert LLM markdown to html
block.html = markdown2.markdown(corrected_markdown)

def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
page_img = page.lowres_image
image_box = image_block.polygon\
.rescale(page.polygon.size, page_img.size)\
.expand(expand, expand)
cropped = page_img.crop(image_box.bbox)
return cropped
Loading

0 comments on commit 5ea06c0

Please sign in to comment.