diff --git a/extract_thinker/exceptions.py b/extract_thinker/exceptions.py
new file mode 100644
index 0000000..9bc52a1
--- /dev/null
+++ b/extract_thinker/exceptions.py
@@ -0,0 +1,11 @@
+class ExtractThinkerError(Exception):
+    """Base exception class for ExtractThinker."""
+    pass
+
+class VisionError(ExtractThinkerError):
+    """Base class for vision-related errors."""
+    pass
+
+class InvalidVisionDocumentLoaderError(VisionError):
+    """Document loader does not support vision features."""
+    pass
\ No newline at end of file
diff --git a/extract_thinker/extractor.py b/extract_thinker/extractor.py
index 4e138e9..aae188c 100644
--- a/extract_thinker/extractor.py
+++ b/extract_thinker/extractor.py
@@ -3,7 +3,6 @@
 from typing import Any, Dict, List, Optional, IO, Type, Union, get_origin
 from instructor.batch import BatchJob
 import uuid
-import litellm
 from pydantic import BaseModel
 from extract_thinker.llm_engine import LLMEngine
 from extract_thinker.concatenation_handler import ConcatenationHandler
@@ -17,7 +16,6 @@
 from extract_thinker.document_loader.llm_interceptor import LlmInterceptor
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from extract_thinker.batch_job import BatchJob
-
 from extract_thinker.models.completion_strategy import CompletionStrategy
 from extract_thinker.utils import (
     add_classification_structure,
@@ -29,6 +27,11 @@
 from copy import deepcopy
 from extract_thinker.pagination_handler import PaginationHandler
 from instructor.exceptions import IncompleteOutputException
+from extract_thinker.exceptions import (
+    ExtractThinkerError,
+    InvalidVisionDocumentLoaderError
+)
+from extract_thinker.utils import is_vision_error, classify_vision_error
 
 class Extractor:
     BATCH_SUPPORTED_MODELS = [
@@ -161,7 +164,10 @@ def extract(
         self.completion_strategy = completion_strategy
 
         if vision:
-            self._handle_vision_mode(source)
+            try:
+                self._handle_vision_mode(source)
+            except ValueError as e:
+                raise InvalidVisionDocumentLoaderError(str(e)) from e
 
         if completion_strategy is not CompletionStrategy.FORBIDDEN:
             return self.extract_with_strategy(source, response_model, vision, completion_strategy)
@@ -184,7 +190,11 @@ def extract(
         except Exception as e:
             if isinstance(e.args[0], IncompleteOutputException):
                 raise ValueError("Incomplete output received and FORBIDDEN strategy is set") from e
-            raise ValueError(f"Failed to extract from source: {str(e)}")
+
+            if vision and is_vision_error(e):
+                classify_vision_error(e, vision)
+
+            raise ExtractThinkerError(f"Failed to extract from source: {str(e)}")
 
     def _map_to_universal_format(
         self,
@@ -334,11 +344,6 @@ def _build_classification_message_content(self, classifications: List[Classifica
         Returns:
             List of content items (text and images) for the message
         """
-        if not litellm.supports_vision(model=self.llm.model):
-            raise ValueError(
-                f"Model {self.llm.model} is not supported for vision, since it's not a vision model."
-            )
-
         message_content = []
         for classification in classifications:
             if not classification.image:
@@ -442,9 +447,6 @@ def _classify_one_image_with_ref(
         Classify doc_image_b64 as either matching or not matching the classification's
         reference image. Return name/confidence.
""" - if not litellm.supports_vision(model=self.llm.model): - raise ValueError(f"Model {self.llm.model} does not support vision images.") - # Convert classification.reference image to base64 if needed: if isinstance(ref_image, str) and os.path.isfile(ref_image): ref_image_b64 = encode_image(ref_image) @@ -512,9 +514,6 @@ def _classify_one_image_no_ref( Minimal fallback if user didn't provide a reference image but we still want a numeric confidence if doc_image matches the classification. """ - if not litellm.supports_vision(model=self.llm.model): - raise ValueError(f"Model {self.llm.model} does not support vision images.") - messages = [ { "role": "system", @@ -945,11 +944,6 @@ def _extract( for interceptor in self.llm_interceptors: interceptor.intercept(self.llm) - if vision and not litellm.supports_vision(model=self.llm.model): - raise ValueError( - f"Model {self.llm.model} is not supported for vision." - ) - # Build messages messages = self._build_messages(self._build_message_content(content, vision), vision) @@ -1238,13 +1232,7 @@ def _handle_vision_mode(self, source: Union[str, IO, list]) -> None: self.document_loader.set_vision_mode(True) return - # No document loader available, check if we can use LLM's vision capabilities - if not litellm.supports_vision(self.llm.model): - raise ValueError( - f"Model {self.llm.model} does not support vision. " - "Please provide a document loader or a model that supports vision." - ) - + # No document loader available, create a new DocumentLoaderLLMImage self.document_loader = DocumentLoaderLLMImage(llm=self.llm) self.document_loader.set_vision_mode(True) diff --git a/extract_thinker/image_splitter.py b/extract_thinker/image_splitter.py index ee2c41d..355e767 100644 --- a/extract_thinker/image_splitter.py +++ b/extract_thinker/image_splitter.py @@ -1,6 +1,4 @@ import base64 -import litellm -import instructor from io import BytesIO from typing import List, Any from extract_thinker.models.classification import Classification @@ -8,14 +6,13 @@ from extract_thinker.models.doc_groups2 import DocGroups2 from extract_thinker.models.eager_doc_group import DocGroupsEager, EagerDocGroup from extract_thinker.splitter import Splitter +from extract_thinker.llm import LLM class ImageSplitter(Splitter): def __init__(self, model: str): - if not litellm.supports_vision(model=model): - raise ValueError(f"Model {model} is not supported for ImageSplitter, since its not a vision model.") self.model = model - self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.MD_JSON) + self.llm = LLM(model) def encode_image(self, image): """ @@ -97,8 +94,7 @@ def belongs_to_same_document(self, }) try: - response = self.client.chat.completions.create( - model=self.model, + response = self.llm.request( messages=[ { "role": "user", @@ -195,8 +191,7 @@ def split_eager_doc_group(self, document: List[dict], classifications: List[Clas }) try: - response = self.client.chat.completions.create( - model=self.model, + response = self.llm.request( messages=[ { "role": "user", diff --git a/extract_thinker/models/doc_groups.py b/extract_thinker/models/doc_groups.py index d17d515..8435d50 100644 --- a/extract_thinker/models/doc_groups.py +++ b/extract_thinker/models/doc_groups.py @@ -1,6 +1,4 @@ from extract_thinker.models.doc_group import DocGroup - - from typing import List diff --git a/extract_thinker/models/eager_doc_group.py b/extract_thinker/models/eager_doc_group.py index 32e2504..bc19926 100644 --- a/extract_thinker/models/eager_doc_group.py +++ 
+++ b/extract_thinker/models/eager_doc_group.py
@@ -1,16 +1,12 @@
 from dataclasses import dataclass
 from typing import List
+from pydantic import BaseModel
 
 @dataclass
 class EagerDocGroup:
     pages: List[str]
     classification: str
 
-
-from typing import List
-from pydantic import BaseModel
-
-
 class DocGroup(BaseModel):
     pages: List[int]
     classification: str
diff --git a/extract_thinker/text_splitter.py b/extract_thinker/text_splitter.py
index 3446e48..2c2cc9f 100644
--- a/extract_thinker/text_splitter.py
+++ b/extract_thinker/text_splitter.py
@@ -1,18 +1,16 @@
-import json
-import instructor
 from typing import List, Any
 from extract_thinker.models.classification import Classification
 from extract_thinker.models.doc_group import DocGroups
 from extract_thinker.models.doc_groups2 import DocGroups2
 from extract_thinker.models.eager_doc_group import DocGroupsEager, EagerDocGroup
 from extract_thinker.splitter import Splitter
-from litellm import completion
+from extract_thinker.llm import LLM
 
 
 class TextSplitter(Splitter):
     def __init__(self, model: str):
         self.model = model
-        self.client = instructor.from_litellm(completion, mode=instructor.Mode.MD_JSON)
+        self.llm = LLM(model)
 
     def belongs_to_same_document(self,
                                  obj1: Any,
@@ -55,8 +53,7 @@ def belongs_to_same_document(self,
         }}"""
 
         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
+            response = self.llm.request(
                 messages=[
                     {
                         "role": "user",
@@ -127,8 +124,7 @@ def split_eager_doc_group(self, document: List[dict], classifications: List[Clas
         }}"""
 
         try:
-            response = self.client.chat.completions.create(
-                model=self.model,
+            response = self.llm.request(
                 messages=[
                     {
                         "role": "user",
diff --git a/extract_thinker/utils.py b/extract_thinker/utils.py
index 9f08565..03aa865 100644
--- a/extract_thinker/utils.py
+++ b/extract_thinker/utils.py
@@ -3,6 +3,7 @@
 import re
 import yaml
 from PIL import Image
+import litellm
 from pydantic import BaseModel
 import typing
 import os
@@ -10,6 +11,7 @@
 import sys
 from typing import Optional, Any, Union, get_origin, get_args, List, Dict
 from pydantic import BaseModel, create_model
+from extract_thinker.exceptions import VisionError
 
 def encode_image(image_source: Union[str, BytesIO, bytes, Image.Image]) -> str:
     """
@@ -535,4 +537,27 @@ def extract_thinking_json(thinking_text: str, response_model: type[BaseModel]) -
     except json.JSONDecodeError as e:
         raise ValueError(f"Invalid JSON format: {str(e)}\nJSON string was: {json_str}")
     except Exception as e:
-        raise ValueError(f"Failed to parse thinking output: {str(e)}\nInput text was: {thinking_text[:200]}...")
\ No newline at end of file
+        raise ValueError(f"Failed to parse thinking output: {str(e)}\nInput text was: {thinking_text[:200]}...")
+
+def is_vision_error(error: Exception) -> bool:
+    if error.args and isinstance(error.args[0], litellm.BadRequestError):
+        return True
+    return False
+
+def classify_vision_error(e: Exception, vision: bool) -> None:
+    """
+    Examines the exception and, if the vision flag is set and the exception wraps a
+    litellm.BadRequestError, re-raises it as a VisionError. Otherwise, re-raises the original exception.
+
+    Args:
+        e: The caught exception.
+        vision: A flag indicating that this was a vision-related operation.
+
+    Raises:
+        VisionError: If vision is True and e wraps a litellm.BadRequestError.
+        Otherwise, re-raises e.
+ """ + if vision and isinstance(e.args[0], litellm.BadRequestError): + raise VisionError(f"Make sure that the model you're using supports vision features: {e.args[0].message}") from e + else: + raise e \ No newline at end of file