Skip to content

Commit

Permalink
Refactor classes
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Dec 20, 2024
1 parent 5ea06c0 commit 46dde3f
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 310 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tqdm import tqdm

from marker.builders.layout import LayoutBuilder
from marker.processors.llm import GoogleModel
from marker.providers.pdf import PdfProvider
from marker.schema import BlockTypes
from marker.schema.blocks import Block
Expand All @@ -21,7 +22,7 @@
from marker.settings import settings


class HighQualityLayoutBuilder(LayoutBuilder):
class LLMLayoutBuilder(LayoutBuilder):
"""
A builder for relabelling blocks to improve the quality of the layout.
Expand Down Expand Up @@ -69,23 +70,19 @@ class HighQualityLayoutBuilder(LayoutBuilder):
"""

def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
self.layout_model = layout_model
self.ocr_error_model = ocr_error_model
super().__init__(layout_model, ocr_error_model, config)

self.model = None
if self.google_api_key is None:
raise ValueError("Google API key is not set")

genai.configure(api_key=self.google_api_key)
self.model = genai.GenerativeModel(self.model_name)
self.model = GoogleModel(self.google_api_key, self.model_name)

def __call__(self, document: Document, provider: PdfProvider):
super().__call__(document, provider)

self.relabel_blocks(document)
try:
self.relabel_blocks(document)
except Exception as e:
print(f"Error relabelling blocks: {e}")

def relabel_blocks(self, document: Document):
pbar = tqdm(desc="High quality layout relabelling")
pbar = tqdm(desc="LLM layout relabelling")
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
futures = []
for page in document.pages:
Expand Down Expand Up @@ -118,12 +115,12 @@ def process_block_relabelling(self, page: PageGroup, block: Block):
},
)

response = self.generate(prompt, image, response_schema)
response = self.model.generate_response(prompt, image, response_schema)
generated_label = None
if response and "label" in response:
generated_label = response["label"]

if generated_label and generated_label != str(block.block_type):
if generated_label and generated_label != str(block.block_type) and generated_label in BlockTypes:
generated_block_class = get_block_class(BlockTypes[generated_label])
generated_block = generated_block_class(
polygon=block.polygon,
Expand All @@ -138,32 +135,4 @@ def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.0
.rescale(page.polygon.size, page_img.size)\
.expand(expand, expand)
cropped = page_img.crop(image_box.bbox)
return cropped

def generate(self, prompt: str, image: PIL.Image.Image, response_schema: content.Schema):
tries = 0
while tries < self.max_retries:
try:
responses = self.model.generate_content(
[prompt, image],
stream=False,
generation_config={
"temperature": 0,
"response_schema": response_schema,
"response_mime_type": "application/json",
},
request_options={'timeout': self.timeout}
)
output = responses.candidates[0].content.parts[0].text
return json.loads(output)

except ResourceExhausted as e:
tries += 1
wait_time = tries * 2
print(f"ResourceExhausted: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{self.max_retries})")
time.sleep(wait_time)
except Exception as e:
print(e)
break

return {}
return cropped
6 changes: 3 additions & 3 deletions marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def common_options(fn):
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with LLMs.")(fn)
fn = click.option("--use_llm", is_flag=True, default=False, help="Enable higher quality processing with LLMs.")(fn)
return fn

def generate_config_dict(self) -> Dict[str, any]:
Expand Down Expand Up @@ -65,8 +65,8 @@ def generate_config_dict(self) -> Dict[str, any]:
config["paginate_output"] = True
case "disable_image_extraction":
config["extract_images"] = False
case "high_quality":
config["high_quality"] = True
case "use_llm":
config["use_llm"] = True
return config

def get_renderer(self):
Expand Down
20 changes: 10 additions & 10 deletions marker/converters/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Dict, List, Type

from marker.builders.document import DocumentBuilder
from marker.builders.high_quality_layout import HighQualityLayoutBuilder
from marker.builders.llm_layout import LLMLayoutBuilder
from marker.builders.layout import LayoutBuilder
from marker.builders.ocr import OcrBuilder
from marker.builders.structure import StructureBuilder
Expand All @@ -17,9 +17,9 @@
from marker.processors.document_toc import DocumentTOCProcessor
from marker.processors.equation import EquationProcessor
from marker.processors.footnote import FootnoteProcessor
from marker.processors.llm.highqualityformprocessor import HighQualityFormProcessor
from marker.processors.llm.highqualitytableprocessor import HighQualityTableProcessor
from marker.processors.high_quality_text import HighQualityTextProcessor
from marker.processors.llm.llm_form import LLMFormProcessor
from marker.processors.llm.llm_table import LLMTableProcessor
from marker.processors.llm.llm_text import LLMTextProcessor
from marker.processors.ignoretext import IgnoreTextProcessor
from marker.processors.line_numbers import LineNumbersProcessor
from marker.processors.list import ListProcessor
Expand Down Expand Up @@ -47,7 +47,7 @@ class PdfConverter(BaseConverter):
instead of the defaults.
"""
override_map: Dict[BlockTypes, Type[Block]] = defaultdict()
high_quality: bool = False
use_llm: bool = False

def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | None = None, renderer: str | None = None, config=None):
super().__init__(config)
Expand All @@ -70,10 +70,10 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
PageHeaderProcessor,
SectionHeaderProcessor,
TableProcessor,
HighQualityTableProcessor,
HighQualityFormProcessor,
LLMTableProcessor,
LLMFormProcessor,
TextProcessor,
HighQualityTextProcessor,
LLMTextProcessor,
DebugProcessor,
]

Expand All @@ -87,8 +87,8 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
self.renderer = renderer

self.layout_builder_class = LayoutBuilder
if self.high_quality:
self.layout_builder_class = HighQualityLayoutBuilder
if self.use_llm:
self.layout_builder_class = LLMLayoutBuilder

def resolve_dependencies(self, cls):
init_signature = inspect.signature(cls.__init__)
Expand Down
84 changes: 84 additions & 0 deletions marker/processors/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from datasets import tqdm

from marker.processors import BaseProcessor
from marker.processors.llm.utils import GoogleModel
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups import PageGroup
from marker.settings import settings


class BaseLLMProcessor(BaseProcessor):
"""
A processor for using LLMs to convert blocks.
Attributes:
google_api_key (str):
The Google API key to use for the Gemini model.
Default is None.
model_name (str):
The name of the Gemini model to use.
Default is "gemini-1.5-flash".
max_retries (int):
The maximum number of retries to use for the Gemini model.
Default is 3.
max_concurrency (int):
The maximum number of concurrent requests to make to the Gemini model.
Default is 3.
timeout (int):
The timeout for requests to the Gemini model.
gemini_rewriting_prompt (str):
The prompt to use for rewriting text.
Default is a string containing the Gemini rewriting prompt.
"""

google_api_key: Optional[str] = settings.GOOGLE_API_KEY
model_name: str = "gemini-1.5-flash"
use_llm: bool = False
max_retries: int = 3
max_concurrency: int = 3
timeout: int = 60
image_expansion_ratio: float = 0.01
gemini_rewriting_prompt = None
block_types = None

def __init__(self, config=None):
super().__init__(config)

self.model = None
if not self.use_llm:
return

self.model = GoogleModel(self.google_api_key, self.model_name)

def __call__(self, document: Document):
if not self.use_llm or self.model is None:
return

self.rewrite_blocks(document)

def process_rewriting(self, document: Document, page: PageGroup, block: Block):
raise NotImplementedError()

def rewrite_blocks(self, document: Document):
pbar = tqdm(desc=f"{self.__class__.__name__} running")
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
for future in as_completed([
executor.submit(self.process_rewriting, document, page, block)
for page in document.pages
for block in page.contained_blocks(document, self.block_types)
]):
future.result() # Raise exceptions if any occurred
pbar.update(1)

pbar.close()

def extract_image(self, page: PageGroup, image_block: Block):
page_img = page.lowres_image
image_box = image_block.polygon\
.rescale(page.polygon.size, page_img.size)\
.expand(self.image_expansion_ratio, self.image_expansion_ratio)
cropped = page_img.crop(image_box.bbox)
return cropped
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import markdown2

from marker.llm import GoogleModel
from marker.processors import BaseProcessor
from marker.processors.llm import BaseLLMProcessor
from marker.processors.llm.utils import GoogleModel
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

from google.ai.generativelanguage_v1beta.types import content
from tqdm import tqdm
Expand All @@ -13,40 +12,10 @@
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup
from marker.settings import settings


class HighQualityFormProcessor(BaseProcessor):
"""
A processor for converting form blocks in a document to markdown.
Attributes:
google_api_key (str):
The Google API key to use for the Gemini model.
Default is None.
model_name (str):
The name of the Gemini model to use.
Default is "gemini-1.5-flash".
max_retries (int):
The maximum number of retries to use for the Gemini model.
Default is 3.
max_concurrency (int):
The maximum number of concurrent requests to make to the Gemini model.
Default is 3.
timeout (int):
The timeout for requests to the Gemini model.
gemini_rewriting_prompt (str):
The prompt to use for rewriting text.
Default is a string containing the Gemini rewriting prompt.
"""

class LLMFormProcessor(BaseLLMProcessor):
block_types = (BlockTypes.Form,)
google_api_key: Optional[str] = settings.GOOGLE_API_KEY
model_name: str = "gemini-1.5-flash"
high_quality: bool = False
max_retries: int = 3
max_concurrency: int = 3
timeout: int = 60

gemini_rewriting_prompt = """You are a text correction expert specializing in accurately reproducing text from images.
You will receive an image of a text block and a markdown representation of the form in the image.
Your task is to correct any errors in the markdown representation, and format it properly.
Expand Down Expand Up @@ -75,35 +44,7 @@ class HighQualityFormProcessor(BaseProcessor):
**Input:**
"""

def __init__(self, config=None):
super().__init__(config)

self.model = None
if not self.high_quality:
return

self.model = GoogleModel(self.google_api_key, self.model_name)

def __call__(self, document: Document):
if not self.high_quality or self.model is None:
return

self.rewrite_blocks(document)

def rewrite_blocks(self, document: Document):
pbar = tqdm(desc="High quality form processor")
with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor:
for future in as_completed([
executor.submit(self.process_rewriting, page, block)
for page in document.pages
for block in page.contained_blocks(document, self.block_types)
]):
future.result() # Raise exceptions if any occurred
pbar.update(1)

pbar.close()

def process_rewriting(self, page: PageGroup, block: Block):
def process_rewriting(self, document: Document, page: PageGroup, block: Block):
cells = block.cells
if cells is None:
# Happens if table/form processors didn't run
Expand Down Expand Up @@ -140,12 +81,4 @@ def process_rewriting(self, page: PageGroup, block: Block):
return

# Convert LLM markdown to html
block.html = markdown2.markdown(corrected_markdown)

def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
page_img = page.lowres_image
image_box = image_block.polygon\
.rescale(page.polygon.size, page_img.size)\
.expand(expand, expand)
cropped = page_img.crop(image_box.bbox)
return cropped
block.html = markdown2.markdown(corrected_markdown)
Loading

0 comments on commit 46dde3f

Please sign in to comment.