Commit

add flashattention v2 support
IlyasMoutawwakil committed Nov 15, 2023
1 parent 5055604 commit 1a05535
Showing 2 changed files with 29 additions and 26 deletions.
52 changes: 27 additions & 25 deletions optimum_benchmark/backends/pytorch/backend.py
@@ -1,4 +1,5 @@
 import gc
+import logging
 from logging import getLogger
 from typing import TYPE_CHECKING, Any, Callable, Dict, List

@@ -19,6 +20,9 @@
 # backend logger
 LOGGER = getLogger("pytorch")

+# disable numexpr.utils logger
+getLogger("numexpr.utils").setLevel(logging.CRITICAL)
+

 class PyTorchBackend(Backend[PyTorchConfig]):
     NAME: str = "pytorch"
@@ -64,14 +68,9 @@ def configure(self, config: PyTorchConfig) -> None:
         self.pretrained_model.eval()

         # BetterTransformer
-        if self.config.bettertransformer:
-            LOGGER.info("\t+ Using optimum.bettertransformer")
-            from optimum.bettertransformer import BetterTransformer
-
-            self.pretrained_model = BetterTransformer.transform(
-                self.pretrained_model,
-                keep_original_model=False,
-            )
+        if self.config.to_bettertransformer:
+            LOGGER.info("\t+ Enabling BetterTransformer")
+            self.pretrained_model.to_bettertransformer()

         # Compile model
         if self.config.torch_compile:
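
The hunk above replaces the explicit `BetterTransformer.transform(...)` call with the `to_bettertransformer()` method that transformers exposes on already-loaded models (it still relies on `optimum` under the hood). A minimal sketch of the new call path, with a placeholder model id:

```python
# Minimal sketch; assumes `transformers` and `optimum` are installed.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model id
model = model.to_bettertransformer()  # reassigning the return value is the documented pattern
```
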
@@ -100,27 +99,26 @@ def configure(self, config: PyTorchConfig) -> None:
         self.pretrained_model = get_peft_model(self.pretrained_model, peft_config=peft_config)

     def load_model_from_pretrained(self) -> None:
-        # attempting inline quantization if possible
-        if self.config.quantization_scheme == "gptq" and self.config.quantization_config:
+        # inline quantization or quantization config modification
+        if self.config.quantization_scheme == "gptq":
             LOGGER.info("\t+ Processing GPTQ config")
             from transformers import GPTQConfig

             self.quantization_config = GPTQConfig(**self.config.quantization_config)
-        elif self.config.quantization_scheme == "awq" and self.config.quantization_config:
+        elif self.config.quantization_scheme == "awq":
             LOGGER.info("\t+ Processing AWQ config")
             from transformers import AwqConfig

             self.quantization_config = AwqConfig(**self.config.quantization_config)
         elif self.config.quantization_scheme == "bnb":
-            LOGGER.info("\t+ Processing BitsAndBytesConfig")
+            LOGGER.info("\t+ Processing BitsAndBytes config")
             from transformers import BitsAndBytesConfig

             self.quantization_config = self.config.quantization_config.copy()
             if self.quantization_config.get("bnb_4bit_compute_dtype", None) is not None:
                 self.quantization_config["bnb_4bit_compute_dtype"] = getattr(
                     torch, self.quantization_config["bnb_4bit_compute_dtype"]
                 )
-                LOGGER.info(f"\t+ Using bnb_4bit_compute_dtype: {self.quantization_config['bnb_4bit_compute_dtype']}")
             self.quantization_config = BitsAndBytesConfig(**self.quantization_config)
         else:
             self.quantization_config = None
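
In the `bnb` branch above, `bnb_4bit_compute_dtype` arrives as a string (e.g. from a YAML or CLI config) and is resolved to a real torch dtype before the config object is built. A sketch of that conversion, where `raw_config` is a hypothetical stand-in for `self.config.quantization_config`:

```python
# Sketch only; `raw_config` is a made-up example dict, not taken from the diff.
import torch
from transformers import BitsAndBytesConfig

raw_config = {"load_in_4bit": True, "bnb_4bit_compute_dtype": "float16"}
if raw_config.get("bnb_4bit_compute_dtype") is not None:
    # resolve the string "float16" into torch.float16
    raw_config["bnb_4bit_compute_dtype"] = getattr(torch, raw_config["bnb_4bit_compute_dtype"])
quantization_config = BitsAndBytesConfig(**raw_config)
```
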
@@ -135,8 +133,8 @@ def load_model_from_pretrained(self) -> None:
             )
             if self.config.device_map is None:
                 LOGGER.info(f"\t+ Moving diffusion pipeline to device: {self.device}")
+                # Diffusers does not support loading with torch.device context manager
                 self.pretrained_model.to(self.device)
-
         elif self.config.device_map is not None:
             LOGGER.info(f"\t+ Loading model with device_map: {self.config.device_map}")
             self.pretrained_model = self.automodel_class.from_pretrained(
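
The comment added in this hunk explains why the diffusers branch moves the pipeline to the device only after loading, rather than loading under the `torch.device` context manager used for regular models. A sketch of that pattern, with a placeholder checkpoint id:

```python
# Sketch; requires `diffusers`, and the checkpoint id below is a placeholder.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipeline.to("cuda")  # moved to the target device only after loading
```
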
@@ -146,18 +144,17 @@ def load_model_from_pretrained(self) -> None:
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
             )
-        elif hasattr(self.pretrained_config, "quantization_config"):
-            LOGGER.info("\t+ Loading quantized model")
+        elif hasattr(self.pretrained_config, "quantization_config") or self.quantization_config is not None:
+            LOGGER.info("\t+ Loading model with low cpu memory usage")
             self.pretrained_model = self.automodel_class.from_pretrained(
                 self.model,
-                device_map=self.device,
-                low_cpu_mem_usage=True,
+                low_cpu_memory_usage=True,
                 torch_dtype=self.torch_dtype,
                 **self.automodel_kwargs,
                 **self.hub_kwargs,
-            )
+            ).to(self.device)
         else:
-            LOGGER.info(f"\t+ Loading model on device: {self.device}")
+            LOGGER.info(f"\t+ Loading model directly on device: {self.device}")
             with self.device:
                 self.pretrained_model = self.automodel_class.from_pretrained(
                     self.model,
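
The pre-quantized branch above now loads with low CPU memory usage and moves the model to the device afterwards. Note that the argument transformers recognizes is `low_cpu_mem_usage`; the `low_cpu_memory_usage` spelling in the added line will not enable it. A sketch of the intended call, with a placeholder model id and assuming a CUDA device:

```python
# Sketch of the intended behaviour; the model id and device are placeholders.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",
    low_cpu_mem_usage=True,   # avoid materializing the weights twice on CPU
    torch_dtype=torch.float16,
).to("cuda")
```
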
@@ -168,13 +165,17 @@ def load_model_from_pretrained(self) -> None:

     @property
     def automodel_kwargs(self) -> Dict[str, Any]:
+        kwargs = {}
+
         if self.quantization_config is not None:
-            return {"quantization_config": self.quantization_config}
-        else:
-            return {}
+            kwargs["quantization_config"] = self.quantization_config
+
+        if self.config.use_flash_attention_2:
+            kwargs["use_flash_attention_2"] = True
+
+        return kwargs

     def load_model_from_config(self) -> None:
-        # TODO: create no_weights tests
         from accelerate import init_empty_weights

         LOGGER.info("\t+ Initializing empty weights model on device: meta")
@@ -238,10 +239,11 @@ def load_model_from_config(self) -> None:

     def prepare_for_inference(self, input_shapes: Dict[str, int], **kwargs) -> None:
         super().prepare_for_inference(input_shapes=input_shapes, **kwargs)
+
         if (self.config.quantization_scheme == "gptq" and self.config.quantization_config.get("desc_act", None)) or (
             hasattr(self.pretrained_config, "quantization_config")
-            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
             and self.pretrained_config.quantization_config.get("desc_act", None)
+            and self.pretrained_config.quantization_config.get("quant_method", None) == "gptq"
         ):
             LOGGER.info("\t+ Setting GPTQ's max_input_length")
             from auto_gptq import exllama_set_max_input_length
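
When a GPTQ model uses act-order (`desc_act`), the exllama kernel's maximum input length is raised through auto-gptq, as the import above hints. A sketch of that call on an already-loaded GPTQ model (the checkpoint id and length are placeholders):

```python
# Sketch; requires `auto-gptq` (and optimum); checkpoint id and length are placeholders.
from auto_gptq import exllama_set_max_input_length
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TheBloke/Llama-2-7B-GPTQ", device_map="auto")
model = exllama_set_max_input_length(model, max_input_length=4096)
```
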
3 changes: 2 additions & 1 deletion optimum_benchmark/backends/pytorch/config.py
@@ -57,7 +57,8 @@ class PyTorchConfig(BackendConfig):
     torch_compile_config: Dict[str, Any] = field(default_factory=dict)

     # optimization options
-    bettertransformer: bool = False
+    to_bettertransformer: bool = False
+    use_flash_attention_2: bool = False

     # quantization options
     quantization_scheme: Optional[str] = None
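
With the renamed `to_bettertransformer` flag and the new `use_flash_attention_2` flag, either path can be toggled from the backend config. A hypothetical sketch of setting them programmatically, assuming the remaining `PyTorchConfig` fields keep their defaults:

```python
# Hypothetical sketch; only the two field names come from the diff above,
# and any required fields without defaults would still need to be supplied.
from optimum_benchmark.backends.pytorch.config import PyTorchConfig

config = PyTorchConfig(to_bettertransformer=False, use_flash_attention_2=True)
```
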
