Skip to content

Commit

Permalink
Merge pull request #278 from enoch3712/253-merge-action---allow-multiple-sources-to-be-added
Browse files Browse the repository at this point in the history

extractor extract multiple sources added with tests
  • Loading branch information
enoch3712 authored Feb 21, 2025
2 parents c764896 + 68ed539 commit fd24ba3
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 18 deletions.
77 changes: 63 additions & 14 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,28 @@ def set_skip_loading(self, skip: bool = True) -> None:

def extract(
self,
source: Union[str, IO, list],
response_model: type[BaseModel],
source: Union[str, IO, List[Union[str, IO]]],
response_model: Type[BaseModel],
vision: bool = False,
content: Optional[str] = None,
completion_strategy: Optional[CompletionStrategy] = CompletionStrategy.FORBIDDEN
) -> Any:
"""
Extract information from the provided source.
Extract information from one or more sources.
If source is a list, it loads each one, converts each to a universal format, and
merges them as if they were a single document. This merged content is then passed
to the extraction logic to produce a final result.
Args:
source: A single file path/stream or a list of them.
response_model: A Pydantic model class for validating the extracted data.
vision: Whether to use vision mode (affecting how content is processed).
content: Optional extra content to prepend to the merged content.
completion_strategy: Strategy for handling completions.
Returns:
The parsed result from the LLM as validated by response_model.
"""
self._validate_dependencies(response_model, vision)
self.extra_content = content
Expand All @@ -173,18 +187,53 @@ def extract(
return self.extract_with_strategy(source, response_model, vision, completion_strategy)

try:
if self._skip_loading:
# Skip loading if flag is set (content from splitting)
unified_content = self._map_to_universal_format(source, vision)
if isinstance(source, list):
all_contents = []
for src in source:
loader = self.get_document_loader(src)
if loader is None:
raise ValueError(f"No suitable document loader found for source: {src}")
# Load the content (e.g. text, images, metadata)
loaded = loader.load(src)
# Map to a universal format that your extraction logic understands.
universal = self._map_to_universal_format(loaded, vision)
all_contents.append(universal)

# Merge the text contents with a clear separator.
merged_text = "\n\n--- Document Separator ---\n\n".join(
item.get("content", "") for item in all_contents
)
# Merge all image lists into one.
merged_images = []
for item in all_contents:
merged_images.extend(item.get("images", []))

merged_content = {
"content": merged_text,
"images": merged_images,
"metadata": {"num_documents": len(all_contents)}
}

# Optionally, prepend any extra content provided by the caller.
if content:
merged_content["content"] = content + "\n\n" + merged_content["content"]

return self._extract(merged_content, response_model, vision)
else:
# Normal loading path
loader = self.get_document_loader(source)
if not loader:
raise ValueError("No suitable document loader found for the input.")
loaded_content = loader.load(source)
unified_content = self._map_to_universal_format(loaded_content, vision)

return self._extract(unified_content, response_model, vision)
# Single source; use existing behavior.
if self._skip_loading:
# Skip loading if flag is set (content from splitting)
unified_content = self._map_to_universal_format(source, vision)
else:
# Normal loading path
loader = self.get_document_loader(source)
if not loader:
raise ValueError("No suitable document loader found for the input.")
loaded_content = loader.load(source)
unified_content = self._map_to_universal_format(loaded_content, vision)

return self._extract(unified_content, response_model, vision)

except IncompleteOutputException as e:
raise ValueError("Incomplete output received and FORBIDDEN strategy is set") from e
except Exception as e:
Expand Down
42 changes: 38 additions & 4 deletions tests/test_extractor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import asyncio
import os
from dotenv import load_dotenv
from extract_thinker.document_loader.document_loader_pdfplumber import DocumentLoaderPdfPlumber
from extract_thinker.extractor import Extractor
from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf
from extract_thinker.llm import LLM, LLMEngine
from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker.models.contract import Contract
from tests.models.invoice import InvoiceContract
from tests.models.ChartWithContent import ChartWithContent
from tests.models.page_contract import ReportContract
Expand All @@ -16,9 +14,10 @@
import pytest
import numpy as np
from litellm import embedding
from extract_thinker.document_loader.document_loader_docling import DoclingConfig, DocumentLoaderDocling
from extract_thinker.document_loader.document_loader_docling import DocumentLoaderDocling
from tests.models.handbook_contract import HandbookContract
from extract_thinker.global_models import get_lite_model, get_big_model
from pydantic import BaseModel, Field


load_dotenv()
Expand Down Expand Up @@ -446,4 +445,39 @@ def test_extract_from_url_docling_and_gpt4o_mini():

# Assert: Verify that the extracted title matches the expected value.
expected_title = "BCOBS 2A.1 Restriction on marketing or providing an optional product for which a fee is payable"
assert result.title == expected_title
assert result.title == expected_title

def test_extract_from_multiple_sources():
    """
    Verify that a single extract() call can merge data from two different
    sources — a local invoice PDF and a remote FCA handbook URL — into one
    response model, loading both through DocumentLoaderDocling.
    """
    # Response model spanning fields from both documents.
    # NOTE(review): the alias text doubles as a natural-language hint to the
    # LLM that the title must come from the URL document rather than the
    # invoice — confirm this alias-as-prompt convention against the project.
    class CombinedData(BaseModel):
        invoice_number: str
        invoice_date: str
        total_amount: float
        handbook_title: str = Field(alias="title of the url, and not the invoice")

    # Arrange: one file-system source and one web source.
    invoice_pdf = os.path.join(cwd, "tests", "files", "invoice.pdf")
    handbook_url = "https://www.handbook.fca.org.uk/handbook/BCOBS/2A/?view=chapter"

    extractor = Extractor()
    extractor.load_document_loader(DocumentLoaderDocling())
    extractor.load_llm(get_big_model())

    # Act: both sources go into a single extract() call as a list.
    result: CombinedData = extractor.extract(
        [invoice_pdf, handbook_url],
        CombinedData,
    )

    # Assert: fields taken from the invoice PDF.
    assert result.invoice_number == "00012"
    assert result.invoice_date == "1/30/23"
    assert result.total_amount == 1125

    # Assert: the title was taken from the handbook page, not the invoice.
    assert "FCA Handbook" in result.handbook_title, f"Expected title to contain 'FCA Handbook', but got: {result.handbook_title}"

0 comments on commit fd24ba3

Please sign in to comment.