Skip to content

Commit

Permalink
Merge pull request #454 from VikParuchuri/vik_dev
Browse files Browse the repository at this point in the history
LLM based image captioning
  • Loading branch information
VikParuchuri authored Jan 2, 2025
2 parents 2353f82 + ba61808 commit 0171bf2
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 7 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,13 @@ Options:
- `--output_dir PATH`: Directory where output files will be saved. Defaults to the value specified in settings.OUTPUT_DIR.
- `--output_format [markdown|json|html]`: Specify the format for the output results.
- `--use_llm`: Uses an LLM to improve accuracy. You must set your Gemini API key using the `GOOGLE_API_KEY` env var.
- `--disable_image_extraction`: Do not extract images from the PDF. When combined with `--use_llm`, each image is replaced with an LLM-generated text description instead.
- `--page_range TEXT`: Specify which pages to process. Accepts comma-separated page numbers and ranges. Example: `--page_range "0,5-10,20"` will process pages 0, 5 through 10, and page 20.
- `--force_ocr`: Force OCR processing on the entire document, even for pages that might contain extractable text.
- `--debug`: Enable debug mode for additional logging and diagnostic information.
- `--processors TEXT`: Override the default processors by providing their full module paths, separated by commas. Example: `--processors "module1.processor1,module2.processor2"`
- `--config_json PATH`: Path to a JSON configuration file containing additional settings.
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "eng,fra,deu"` for English, French, and German.
- `--languages TEXT`: Optionally specify which languages to use for OCR processing. Accepts a comma-separated list. Example: `--languages "en,fr,de"` for English, French, and German.
- `config --help`: List all available builders, processors, and converters, and their associated configuration. These values can be used to build a JSON configuration file for additional tweaking of marker defaults.

The list of supported languages for surya OCR is [here](https://github.com/VikParuchuri/surya/blob/master/surya/languages.py). If you don't need OCR, marker can work with any language.
Expand Down
6 changes: 3 additions & 3 deletions marker/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import os

from marker.processors.llm.llm_complex import LLMComplexRegionProcessor

os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import inspect
Expand All @@ -14,6 +11,7 @@
from marker.builders.ocr import OcrBuilder
from marker.builders.structure import StructureBuilder
from marker.converters import BaseConverter
from marker.processors.llm.llm_complex import LLMComplexRegionProcessor
from marker.processors.blockquote import BlockquoteProcessor
from marker.processors.code import CodeProcessor
from marker.processors.debug import DebugProcessor
Expand All @@ -23,6 +21,7 @@
from marker.processors.llm.llm_form import LLMFormProcessor
from marker.processors.llm.llm_table import LLMTableProcessor
from marker.processors.llm.llm_text import LLMTextProcessor
from marker.processors.llm.llm_image_description import LLMImageDescriptionProcessor
from marker.processors.ignoretext import IgnoreTextProcessor
from marker.processors.line_numbers import LineNumbersProcessor
from marker.processors.list import ListProcessor
Expand Down Expand Up @@ -78,6 +77,7 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
TextProcessor,
LLMTextProcessor,
LLMComplexRegionProcessor,
LLMImageDescriptionProcessor,
DebugProcessor,
]

Expand Down
62 changes: 62 additions & 0 deletions marker/processors/llm/llm_image_description.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from marker.processors.llm import BaseLLMProcessor

from google.ai.generativelanguage_v1beta.types import content

from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup


class LLMImageDescriptionProcessor(BaseLLMProcessor):
    """Attach an LLM-written text description to Picture and Figure blocks.

    The processor is a no-op while ``extract_images`` is true; descriptions
    are only generated when images are not being extracted, because the
    description is meant to stand in for the image.
    """
    block_types = (BlockTypes.Picture, BlockTypes.Figure,)
    # When True (the default) this processor returns without doing anything.
    extract_images: bool = True
    image_description_prompt = """You are a document analysis expert who specializes in creating text descriptions for images.
You will receive an image of a picture or figure. Your job will be to create a short description of the image.
**Instructions:**
1. Carefully examine the provided image.
2. Analyze any text that was extracted from within the image.
3. Output a 3-4 sentence description of the image. Make sure there is enough specific detail to accurately describe the image. If there are numbers included, try to be specific.
**Example:**
Input:
```text
"Fruit Preference Survey"
20, 15, 10
Apples, Bananas, Oranges
```
Output:
In this figure, a bar chart titled "Fruit Preference Survey" is showing the number of people who prefer different types of fruits. The x-axis shows the types of fruits, and the y-axis shows the number of people. The bar chart shows that most people prefer apples, followed by bananas and oranges. 20 people prefer apples, 15 people prefer bananas, and 10 people prefer oranges.
**Input:**
"""

    def process_rewriting(self, document: Document, page: PageGroup, block: Block):
        """Ask the model for a description of ``block`` and store it on the block.

        Args:
            document: Document the block belongs to; used to read the block's raw text.
            page: Page the block sits on; used to extract the block's image.
            block: Picture or Figure block to describe.

        Returns:
            None. On success ``block.description`` is set. If the response is
            missing, lacks the ``image_description`` key, or carries a value
            that is not a string of at least 10 characters, the block's
            metadata is updated with ``llm_error_count=1`` instead.
        """
        if self.extract_images:
            # Descriptions replace images, so there is nothing to do while
            # images are still being extracted.
            return

        # NOTE(review): the raw text is wrapped in single backticks inside the
        # fence, unlike the example in the prompt - confirm this is intended.
        prompt = self.image_description_prompt + '```text\n`' + block.raw_text(document) + '`\n```\n'
        image = self.extract_image(page, block)
        response_schema = content.Schema(
            type=content.Type.OBJECT,
            enum=[],
            required=["image_description"],
            properties={
                "image_description": content.Schema(
                    type=content.Type.STRING
                )
            },
        )

        response = self.model.generate_response(prompt, image, block, response_schema)

        if not response or "image_description" not in response:
            block.update_metadata(llm_error_count=1)
            return

        image_description = response["image_description"]
        # A malformed reply can carry null or a non-string value here; count it
        # as an LLM error rather than letting len() raise. Strings shorter than
        # 10 characters are rejected as too short to be a usable description.
        if not isinstance(image_description, str) or len(image_description) < 10:
            block.update_metadata(llm_error_count=1)
            return

        block.description = image_description
4 changes: 3 additions & 1 deletion marker/renderers/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def extract_html(self, document, document_output, level=0):
for ref in content_refs:
src = ref.get('src')
sub_images = {}
content = ""
for item in document_output.children:
if item.id == src:
content, sub_images_ = self.extract_html(document, item, level + 1)
Expand All @@ -61,7 +62,8 @@ def extract_html(self, document, document_output, level=0):
images[image_name] = image
ref.replace_with(BeautifulSoup(f"<p><img src='{image_name}'></p>", 'html.parser'))
else:
ref.replace_with('')
# This will be the image description if using llm mode, or empty if not
ref.replace_with(BeautifulSoup(f"{content}", 'html.parser'))
elif ref_block_id.block_type in self.page_blocks:
images.update(sub_images)
if self.paginate_output:
Expand Down
6 changes: 5 additions & 1 deletion marker/schema/blocks/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

class Figure(Block):
    """Figure block that renders as its text description when one is set."""
    block_type: BlockTypes = BlockTypes.Figure
    # Text description of the figure; stays None unless something assigns one.
    description: str | None = None

    def assemble_html(self, child_blocks, parent_structure):
        """Return the HTML for this figure.

        Args:
            child_blocks: Unused by this block.
            parent_structure: Unused by this block.

        Returns:
            A ``<p role='img'>`` element carrying the block id and the
            description, or an empty string when there is no description.
        """
        if not self.description:
            return ""

        # Function-scope import: the module's import header is outside this hunk.
        import html

        # The description is free text, so escape it before placing it in
        # markup; a stray "<" or "&" would otherwise corrupt the element.
        safe_description = html.escape(self.description)
        return f"<p role='img' data-original-image-id='{self.id}'>Image {self.id} description: {safe_description}</p>"
6 changes: 5 additions & 1 deletion marker/schema/blocks/picture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

class Picture(Block):
    """Picture block that renders as its text description when one is set."""
    block_type: BlockTypes = BlockTypes.Picture
    # Text description of the picture; stays None unless something assigns one.
    description: str | None = None

    def assemble_html(self, child_blocks, parent_structure):
        """Return the HTML for this picture.

        Args:
            child_blocks: Unused by this block.
            parent_structure: Unused by this block.

        Returns:
            A ``<p role='img'>`` element carrying the block id and the
            description, or an empty string when there is no description.
        """
        if not self.description:
            return ""

        # Function-scope import: the module's import header is outside this hunk.
        import html

        # The description is free text, so escape it before placing it in
        # markup; a stray "<" or "&" would otherwise corrupt the element.
        safe_description = html.escape(self.description)
        return f"<p role='img' data-original-image-id='{self.id}'>Image {self.id} description: {safe_description}</p>"
31 changes: 31 additions & 0 deletions tests/processors/test_llm_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import pytest

from marker.processors.llm.llm_form import LLMFormProcessor
from marker.processors.llm.llm_image_description import LLMImageDescriptionProcessor
from marker.processors.llm.llm_table import LLMTableProcessor
from marker.processors.llm.llm_text import LLMTextProcessor
from marker.processors.table import TableProcessor
from marker.renderers.markdown import MarkdownRenderer
from marker.schema import BlockTypes

@pytest.mark.filename("form_1040.pdf")
Expand Down Expand Up @@ -108,3 +110,32 @@ def test_llm_text_processor(pdf_document, mocker):
contained_spans = text_lines[0].contained_blocks(pdf_document, (BlockTypes.Span,))
assert contained_spans[0].text == "Text\n" # Newline inserted at end of line
assert contained_spans[0].formats == ["italic"]


@pytest.mark.filename("A17_FlightPlan.pdf")
@pytest.mark.config({"page_range": [0]})
def test_llm_caption_processor_disabled(pdf_document):
    """The processor must leave blocks undescribed when ``extract_images`` is not turned off."""
    processor = LLMImageDescriptionProcessor({"use_llm": True, "google_api_key": "test"})
    processor(pdf_document)

    contained_pictures = pdf_document.contained_blocks((BlockTypes.Picture, BlockTypes.Figure))
    # all() is vacuously true for an empty list, so first make sure the page
    # actually yielded image blocks to check.
    assert contained_pictures, "expected at least one Picture/Figure block on the page"
    assert all(picture.description is None for picture in contained_pictures)

@pytest.mark.filename("A17_FlightPlan.pdf")
@pytest.mark.config({"page_range": [0]})
def test_llm_caption_processor(pdf_document, mocker):
    """With image extraction off, every image block gets the model's description and it is rendered."""
    expected = "This is an image description."

    # Stand-in model class whose instances always answer with the fixed description.
    fake_model_cls = Mock()
    fake_model_cls.return_value.generate_response.return_value = {"image_description": expected}
    mocker.patch("marker.processors.llm.GoogleModel", fake_model_cls)

    config = {"use_llm": True, "google_api_key": "test", "extract_images": False}
    run_processor = LLMImageDescriptionProcessor(config)
    run_processor(pdf_document)

    image_blocks = pdf_document.contained_blocks((BlockTypes.Picture, BlockTypes.Figure))
    assert all(image_block.description == expected for image_block in image_blocks)

    # The description must also survive rendering to markdown.
    render = MarkdownRenderer({"extract_images": False})
    rendered_markdown = render(pdf_document).markdown

    assert expected in rendered_markdown
23 changes: 23 additions & 0 deletions tests/renderers/test_extract_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from marker.renderers.markdown import MarkdownRenderer


@pytest.mark.config({"page_range": [0]})
@pytest.mark.filename("A17_FlightPlan.pdf")
def test_disable_extract_images(pdf_document):
    """Rendering the page with image extraction turned off produces empty markdown."""
    render = MarkdownRenderer({"extract_images": False})
    rendered = render(pdf_document)

    # No image reference and no other content is expected in the output.
    assert len(rendered.markdown) == 0


@pytest.mark.config({"page_range": [0]})
@pytest.mark.filename("A17_FlightPlan.pdf")
def test_extract_images(pdf_document):
    """Rendering the page with default settings references a jpeg image in the markdown."""
    render = MarkdownRenderer()
    rendered = render(pdf_document)

    # The output should mention a jpeg.
    assert "jpeg" in rendered.markdown

0 comments on commit 0171bf2

Please sign in to comment.