add better bettertransformers support (#509)
* add better bettertransformers support

* Update libs/infinity_emb/infinity_emb/transformer/acceleration.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* update spelling

* update torch compile

* fmt

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
michaelfeil and greptile-apps[bot] authored Jan 3, 2025
1 parent c69f927 commit ef4c424
Showing 7 changed files with 67 additions and 36 deletions.
8 changes: 4 additions & 4 deletions libs/infinity_emb/Makefile
@@ -73,18 +73,18 @@ spell_fix:
 	poetry run codespell --toml pyproject.toml -w
 
 benchmark_embed: tests/data/benchmark/benchmark_embed.json
-	ab -n 10 -c 10 -l -s 480 \
+	ab -n 50 -c 50 -l -s 480 \
 	-T 'application/json' \
 	-p $< \
 	http://127.0.0.1:7997/embeddings
-	# sudo apt-get apache2-utils
+	# sudo apt-get install apache2-utils
 
 benchmark_embed_vision: tests/data/benchmark/benchmark_embed_image.json
-	ab -n 100 -c 50 -l -s 480 \
+	ab -n 50 -c 50 -l -s 480 \
 	-T 'application/json' \
 	-p $< \
 	http://127.0.0.1:7997/embeddings
-	# sudo apt-get apache2-utils
+	# sudo apt-get install apache2-utils
 
 # Generate CLI v2 documentation
 cli_v2_docs:
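
Note on the benchmark targets: ab here is ApacheBench, where -n sets the total number of requests, -c the concurrency, -p the file POSTed as the request body, -T its content type, -s the per-request timeout in seconds, and -l accepts responses of variable length. The change aligns both embedding benchmarks to 50 requests at concurrency 50 and fixes the apt hint, which was missing the install verb.
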
20 changes: 19 additions & 1 deletion libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -4,12 +4,13 @@
 import os
 from typing import TYPE_CHECKING
 
-from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH
+from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH, CHECK_TRANSFORMERS
 from infinity_emb.primitives import Device
 
 if CHECK_OPTIMUM.is_available:
     from optimum.bettertransformer import (  # type: ignore[import-untyped]
         BetterTransformer,
+        BetterTransformerManager,
     )
 
 if CHECK_TORCH.is_available:
@@ -19,6 +20,9 @@
     # allow TF32 for better performance
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.allow_tf32 = True
+if CHECK_TRANSFORMERS.is_available:
+    from transformers import AutoConfig  # type: ignore[import-untyped]
+
 
 if TYPE_CHECKING:
     from logging import Logger
@@ -28,6 +32,20 @@
     from infinity_emb.args import EngineArgs
 
 
+def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool:
+    """Check whether a conversion to BetterTransformer should be attempted."""
+    if not engine_args.bettertransformer:
+        return False
+
+    config = AutoConfig.from_pretrained(
+        pretrained_model_name_or_path=engine_args.model_name_or_path,
+        revision=engine_args.revision,
+        trust_remote_code=engine_args.trust_remote_code,
+    )
+
+    return config.model_type in BetterTransformerManager.MODEL_MAPPING
+
+
 def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", logger: "Logger"):
     if not engine_args.bettertransformer:
         return model
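
The new helper gates conversion on optimum's own support table rather than trying and failing. A minimal sketch of the underlying check, assuming optimum and transformers are installed (bert-base-uncased is just an illustrative checkpoint):

from optimum.bettertransformer import BetterTransformerManager
from transformers import AutoConfig

# Load only the config (cheap), not the weights, to read the model_type.
config = AutoConfig.from_pretrained("bert-base-uncased")

# MODEL_MAPPING maps each supported model_type string to its converter;
# architectures absent from the mapping cannot be converted.
print(config.model_type)  # "bert"
print(config.model_type in BetterTransformerManager.MODEL_MAPPING)

This lets every torch backend below skip the conversion attempt for unsupported architectures instead of raising inside optimum.
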
6 changes: 1 addition & 5 deletions libs/infinity_emb/infinity_emb/transformer/audio/torch.py
@@ -26,18 +26,14 @@ class TorchAudioModel(BaseAudioEmbedModel):
     def __init__(self, *, engine_args: EngineArgs):
         CHECK_TORCH.mark_required()
         CHECK_TRANSFORMERS.mark_required()
 
         self.model = AutoModel.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
             trust_remote_code=engine_args.trust_remote_code,
+            # attn_implementation="eager" if engine_args.bettertransformer else None,
         )
 
-        # self.model = to_bettertransformer(
-        #     self.model,
-        #     engine_args,
-        #     logger,
-        # )
         self.processor = AutoProcessor.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
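
For the audio model, BetterTransformer stays disabled: the commented-out to_bettertransformer call is removed outright rather than gated, and only a commented attn_implementation hint is kept as a marker for where support would plug in later.
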
21 changes: 13 additions & 8 deletions libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -5,7 +5,10 @@
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
 from infinity_emb.transformer.abstract import BaseClassifer
-from infinity_emb.transformer.acceleration import to_bettertransformer
+from infinity_emb.transformer.acceleration import (
+    to_bettertransformer,
+    check_if_bettertransformer_possible,
+)
 from infinity_emb.transformer.quantization.interface import quant_interface
 from infinity_emb.primitives import Device
@@ -23,7 +26,8 @@ def __init__(
     ) -> None:
         CHECK_TRANSFORMERS.mark_required()
         model_kwargs = {}
-        if engine_args.bettertransformer:
+        attempt_bt = check_if_bettertransformer_possible(engine_args)
+        if engine_args.bettertransformer and attempt_bt:
             model_kwargs["attn_implementation"] = "eager"
         ls = engine_args._loading_strategy
         assert ls is not None
@@ -41,17 +45,18 @@ def __init__(
             model_kwargs=model_kwargs,
         )
 
-        self._pipe.model = to_bettertransformer(
-            self._pipe.model,
-            engine_args,
-            logger,
-        )
-
         if ls.quantization_dtype is not None:
             self._pipe.model = quant_interface(  # TODO: add ls.quantization_dtype and ls.placement
                 self._pipe.model, engine_args.dtype, device=Device[self._pipe.model.device.type]
             )
 
+        if engine_args.bettertransformer and attempt_bt:
+            self._pipe.model = to_bettertransformer(
+                self._pipe.model,  # type: ignore
+                engine_args,
+                logger,
+            )
+
         if engine_args.compile:
             logger.info("using torch.compile(dynamic=True)")
             self._pipe.model = torch.compile(self._pipe.model, dynamic=True)
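
The pattern above recurs in the cross-encoder and sentence-transformer backends below: the support check runs once up front; attn_implementation is forced to "eager" only when a conversion will actually be attempted, since optimum's BetterTransformer patches the eager attention path while newer transformers releases otherwise default to SDPA; the conversion itself sits behind the same guard; and torch.compile(dynamic=True) marks shapes dynamic up front so varying batch and sequence lengths do not trigger recompilation.
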
20 changes: 12 additions & 8 deletions libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -28,7 +28,10 @@ class CrossEncoder:  # type: ignore[no-redef]
     from torch import Tensor
 
 
-from infinity_emb.transformer.acceleration import to_bettertransformer
+from infinity_emb.transformer.acceleration import (
+    to_bettertransformer,
+    check_if_bettertransformer_possible,
+)
 
 __all__ = [
     "CrossEncoderPatched",
@@ -42,7 +45,8 @@ def __init__(self, *, engine_args: EngineArgs):
         CHECK_SENTENCE_TRANSFORMERS.mark_required()
 
         model_kwargs = {}
-        if engine_args.bettertransformer:
+        attempt_bt = check_if_bettertransformer_possible(engine_args)
+        if engine_args.bettertransformer and attempt_bt:
             model_kwargs["attn_implementation"] = "eager"
 
         ls = engine_args._loading_strategy
@@ -66,12 +70,12 @@ def __init__(self, *, engine_args: EngineArgs):
 
         self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
         self.model.eval()  # type: ignore
-
-        self.model = to_bettertransformer(
-            self.model,  # type: ignore
-            engine_args,
-            logger,
-        )
+        if engine_args.bettertransformer and attempt_bt:
+            self.model = to_bettertransformer(
+                self.model,  # type: ignore
+                engine_args,
+                logger,
+            )
 
         self.model.to(ls.loading_dtype)
 
@@ -16,7 +16,10 @@
 from infinity_emb.log_handler import logger
 from infinity_emb.primitives import Device
 from infinity_emb.transformer.abstract import BaseEmbedder
-from infinity_emb.transformer.acceleration import to_bettertransformer
+from infinity_emb.transformer.acceleration import (
+    to_bettertransformer,
+    check_if_bettertransformer_possible,
+)
 from infinity_emb.transformer.quantization.interface import (
     quant_embedding_decorator,
     quant_interface,
@@ -56,7 +59,8 @@ def __init__(self, *, engine_args=EngineArgs):
         CHECK_SENTENCE_TRANSFORMERS.mark_required()
 
         model_kwargs = {}
-        if engine_args.bettertransformer:
+        attempt_bt = check_if_bettertransformer_possible(engine_args)
+        if engine_args.bettertransformer and attempt_bt:
             model_kwargs["attn_implementation"] = "eager"
 
         ls = engine_args._loading_strategy
@@ -88,12 +92,12 @@ def __init__(self, *, engine_args=EngineArgs):
         self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
         self.eval()
         self.engine_args = engine_args
-
-        fm.auto_model = to_bettertransformer(
-            fm.auto_model,
-            engine_args,
-            logger,
-        )
+        if engine_args.bettertransformer and attempt_bt:
+            fm.auto_model = to_bettertransformer(
+                fm.auto_model,
+                engine_args,
+                logger,
+            )
 
         fm.to(ls.loading_dtype)
 
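Here fm is presumably the SentenceTransformer's first module (the Transformer wrapper), so the guarded conversion targets fm.auto_model — the underlying Hugging Face PreTrainedModel — before the module is moved to the loading dtype.
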
@@ -11,7 +11,9 @@
 
 device = Device.cpu if torch.backends.mps.is_available() else Device.auto
 
-SHOULD_TORCH_COMPILE = sys.platform == "linux" and sys.version_info < (3, 12)
+SHOULD_TORCH_COMPILE = (
+    sys.platform == "linux" and sys.version_info < (3, 12) and torch.cuda.is_available()
+)
 
 
 def test_crossencoder():
@@ -43,7 +45,9 @@ def test_crossencoder():
 def test_patched_crossencoder_vs_sentence_transformers():
     model = CrossEncoderPatched(
         engine_args=EngineArgs(
-            model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1", compile=True, device=device
+            model_name_or_path="mixedbread-ai/mxbai-rerank-xsmall-v1",
+            compile=SHOULD_TORCH_COMPILE,
+            device=device,
         )
     )
     model_unpatched = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
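
The test guard now also requires a CUDA device, presumably because torch.compile support off Linux, on newer CPython versions, and on CPU-only runners was still spotty at the time; the comparison test then opts into compilation via compile=SHOULD_TORCH_COMPILE instead of hard-coding compile=True.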