Skip to content

Commit

Permalink
Merge pull request #273 from enoch3712/265-litellmsupports_vision-dis…
Browse files Browse the repository at this point in the history
…aggregation

Removed the litellm vision-support check; the exception is now classified after being raised.
  • Loading branch information
enoch3712 authored Feb 21, 2025
2 parents 42413ab + c605539 commit 36a9048
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 52 deletions.
11 changes: 11 additions & 0 deletions extract_thinker/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class ExtractThinkerError(Exception):
    """Root of the ExtractThinker exception hierarchy."""


class VisionError(ExtractThinkerError):
    """Raised for failures in vision (image) processing."""


class InvalidVisionDocumentLoaderError(VisionError):
    """Raised when the configured document loader cannot handle vision input."""
42 changes: 15 additions & 27 deletions extract_thinker/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Dict, List, Optional, IO, Type, Union, get_origin
from instructor.batch import BatchJob
import uuid
import litellm
from pydantic import BaseModel
from extract_thinker.llm_engine import LLMEngine
from extract_thinker.concatenation_handler import ConcatenationHandler
Expand All @@ -17,7 +16,6 @@
from extract_thinker.document_loader.llm_interceptor import LlmInterceptor
from concurrent.futures import ThreadPoolExecutor, as_completed
from extract_thinker.batch_job import BatchJob

from extract_thinker.models.completion_strategy import CompletionStrategy
from extract_thinker.utils import (
add_classification_structure,
Expand All @@ -29,6 +27,11 @@
from copy import deepcopy
from extract_thinker.pagination_handler import PaginationHandler
from instructor.exceptions import IncompleteOutputException
from extract_thinker.exceptions import (
ExtractThinkerError,
InvalidVisionDocumentLoaderError
)
from extract_thinker.utils import is_vision_error, classify_vision_error

class Extractor:
BATCH_SUPPORTED_MODELS = [
Expand Down Expand Up @@ -161,7 +164,10 @@ def extract(
self.completion_strategy = completion_strategy

if vision:
self._handle_vision_mode(source)
try:
self._handle_vision_mode(source)
except ValueError as e:
raise InvalidVisionDocumentLoaderError(str(e))

if completion_strategy is not CompletionStrategy.FORBIDDEN:
return self.extract_with_strategy(source, response_model, vision, completion_strategy)
Expand All @@ -184,7 +190,11 @@ def extract(
except Exception as e:
if isinstance(e.args[0], IncompleteOutputException):
raise ValueError("Incomplete output received and FORBIDDEN strategy is set") from e
raise ValueError(f"Failed to extract from source: {str(e)}")

if vision & is_vision_error(e):
raise classify_vision_error(e, self.llm.model if self.llm else None)

raise ExtractThinkerError(f"Failed to extract from source: {str(e)}")

def _map_to_universal_format(
self,
Expand Down Expand Up @@ -334,11 +344,6 @@ def _build_classification_message_content(self, classifications: List[Classifica
Returns:
List of content items (text and images) for the message
"""
if not litellm.supports_vision(model=self.llm.model):
raise ValueError(
f"Model {self.llm.model} is not supported for vision, since it's not a vision model."
)

message_content = []
for classification in classifications:
if not classification.image:
Expand Down Expand Up @@ -442,9 +447,6 @@ def _classify_one_image_with_ref(
Classify doc_image_b64 as either matching or not matching
the classification's reference image. Return name/confidence.
"""
if not litellm.supports_vision(model=self.llm.model):
raise ValueError(f"Model {self.llm.model} does not support vision images.")

# Convert classification.reference image to base64 if needed:
if isinstance(ref_image, str) and os.path.isfile(ref_image):
ref_image_b64 = encode_image(ref_image)
Expand Down Expand Up @@ -512,9 +514,6 @@ def _classify_one_image_no_ref(
Minimal fallback if user didn't provide a reference image
but we still want a numeric confidence if doc_image matches the classification.
"""
if not litellm.supports_vision(model=self.llm.model):
raise ValueError(f"Model {self.llm.model} does not support vision images.")

messages = [
{
"role": "system",
Expand Down Expand Up @@ -945,11 +944,6 @@ def _extract(
for interceptor in self.llm_interceptors:
interceptor.intercept(self.llm)

if vision and not litellm.supports_vision(model=self.llm.model):
raise ValueError(
f"Model {self.llm.model} is not supported for vision."
)

# Build messages
messages = self._build_messages(self._build_message_content(content, vision), vision)

Expand Down Expand Up @@ -1238,13 +1232,7 @@ def _handle_vision_mode(self, source: Union[str, IO, list]) -> None:
self.document_loader.set_vision_mode(True)
return

# No document loader available, check if we can use LLM's vision capabilities
if not litellm.supports_vision(self.llm.model):
raise ValueError(
f"Model {self.llm.model} does not support vision. "
"Please provide a document loader or a model that supports vision."
)

# No document loader available, create a new DocumentLoaderLLMImage
self.document_loader = DocumentLoaderLLMImage(llm=self.llm)
self.document_loader.set_vision_mode(True)

Expand Down
13 changes: 4 additions & 9 deletions extract_thinker/image_splitter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
import base64
import litellm
import instructor
from io import BytesIO
from typing import List, Any
from extract_thinker.models.classification import Classification
from extract_thinker.models.doc_group import DocGroups
from extract_thinker.models.doc_groups2 import DocGroups2
from extract_thinker.models.eager_doc_group import DocGroupsEager, EagerDocGroup
from extract_thinker.splitter import Splitter
from extract_thinker.llm import LLM

class ImageSplitter(Splitter):

def __init__(self, model: str):
if not litellm.supports_vision(model=model):
raise ValueError(f"Model {model} is not supported for ImageSplitter, since its not a vision model.")
self.model = model
self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.MD_JSON)
self.llm = LLM(model)

def encode_image(self, image):
"""
Expand Down Expand Up @@ -97,8 +94,7 @@ def belongs_to_same_document(self,
})

try:
response = self.client.chat.completions.create(
model=self.model,
response = self.llm.request(
messages=[
{
"role": "user",
Expand Down Expand Up @@ -195,8 +191,7 @@ def split_eager_doc_group(self, document: List[dict], classifications: List[Clas
})

try:
response = self.client.chat.completions.create(
model=self.model,
response = self.llm.request(
messages=[
{
"role": "user",
Expand Down
2 changes: 0 additions & 2 deletions extract_thinker/models/doc_groups.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from extract_thinker.models.doc_group import DocGroup


from typing import List


Expand Down
6 changes: 1 addition & 5 deletions extract_thinker/models/eager_doc_group.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
from dataclasses import dataclass
from typing import List
from pydantic import BaseModel

@dataclass
class EagerDocGroup:
    # Page contents of one eagerly-split document group, in source order.
    pages: List[str]
    # Presumably the name of the matched Classification — confirm against callers.
    classification: str


from typing import List
from pydantic import BaseModel


class DocGroup(BaseModel):
pages: List[int]
classification: str
Expand Down
12 changes: 4 additions & 8 deletions extract_thinker/text_splitter.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import json
import instructor
from typing import List, Any
from extract_thinker.models.classification import Classification
from extract_thinker.models.doc_group import DocGroups
from extract_thinker.models.doc_groups2 import DocGroups2
from extract_thinker.models.eager_doc_group import DocGroupsEager, EagerDocGroup
from extract_thinker.splitter import Splitter
from litellm import completion
from extract_thinker.llm import LLM

class TextSplitter(Splitter):

def __init__(self, model: str):
self.model = model
self.client = instructor.from_litellm(completion, mode=instructor.Mode.MD_JSON)
self.llm = LLM(model)

def belongs_to_same_document(self,
obj1: Any,
Expand Down Expand Up @@ -55,8 +53,7 @@ def belongs_to_same_document(self,
}}"""

try:
response = self.client.chat.completions.create(
model=self.model,
response = self.llm.request(
messages=[
{
"role": "user",
Expand Down Expand Up @@ -127,8 +124,7 @@ def split_eager_doc_group(self, document: List[dict], classifications: List[Clas
}}"""

try:
response = self.client.chat.completions.create(
model=self.model,
response = self.llm.request(
messages=[
{
"role": "user",
Expand Down
27 changes: 26 additions & 1 deletion extract_thinker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import re
import yaml
from PIL import Image
import litellm
from pydantic import BaseModel
import typing
import os
from io import BytesIO
import sys
from typing import Optional, Any, Union, get_origin, get_args, List, Dict
from pydantic import BaseModel, create_model
from extract_thinker.exceptions import VisionError

def encode_image(image_source: Union[str, BytesIO, bytes, Image.Image]) -> str:
"""
Expand Down Expand Up @@ -535,4 +537,27 @@ def extract_thinking_json(thinking_text: str, response_model: type[BaseModel]) -
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {str(e)}\nJSON string was: {json_str}")
except Exception as e:
raise ValueError(f"Failed to parse thinking output: {str(e)}\nInput text was: {thinking_text[:200]}...")
raise ValueError(f"Failed to parse thinking output: {str(e)}\nInput text was: {thinking_text[:200]}...")

def is_vision_error(error: Exception) -> bool:
    """Return True if *error* wraps a ``litellm.BadRequestError``.

    The extractor wraps provider failures so that the original litellm
    exception arrives as ``error.args[0]``; a BadRequestError in that slot
    is treated as a (potential) vision-capability failure.

    Args:
        error: The caught exception to inspect.

    Returns:
        True when ``error.args[0]`` is a ``litellm.BadRequestError``,
        False otherwise (including when ``error`` carries no args —
        the original code raised IndexError in that case).
    """
    # Short-circuit on empty args: Exception() must not raise IndexError here.
    return bool(error.args) and isinstance(error.args[0], litellm.BadRequestError)

def classify_vision_error(e: Exception, vision: bool) -> None:
    """Re-raise *e* as a :class:`VisionError` when it is a vision-capability failure.

    If ``vision`` is truthy and the caught exception wraps a
    ``litellm.BadRequestError`` (as ``e.args[0]``), the error is translated
    into a ``VisionError`` with a hint about model vision support; otherwise
    the original exception is re-raised unchanged.

    Args:
        e: The caught exception.
        vision: Truthy when the failing operation used vision features.
            NOTE(review): the visible caller passes the model name here, not
            a bool — any truthy value enables the translation; confirm intent.

    Raises:
        VisionError: if ``vision`` is truthy and ``e.args[0]`` is a
            ``litellm.BadRequestError``.
        Exception: otherwise, re-raises ``e``.
    """
    # Guard e.args access: an args-less exception must be re-raised as-is,
    # not masked by an IndexError from e.args[0].
    if vision and e.args and isinstance(e.args[0], litellm.BadRequestError):
        raise VisionError(f"Make sure that the model you're using supports vision features: {e.args[0].message}") from e
    raise e

0 comments on commit 36a9048

Please sign in to comment.