From 0986b6345af87a8fe0dab1e29b17a3cac7cc1bce Mon Sep 17 00:00:00 2001 From: Ruheena Suhani Shaik Date: Fri, 24 Jan 2025 06:03:56 +0200 Subject: [PATCH] Supports Bitsandbytes development on HPU (#117) --- optimum/habana/accelerate/utils/modeling.py | 52 ++++ optimum/habana/quantizers/bitsandbytes.py | 265 ++++++++++++++++++ optimum/habana/transformers/modeling_utils.py | 17 ++ .../models/llama/modeling_llama.py | 12 +- tests/test_bnb_inference.py | 66 +++++ tests/test_bnb_qlora.py | 152 ++++++++++ 6 files changed, 562 insertions(+), 2 deletions(-) create mode 100644 optimum/habana/accelerate/utils/modeling.py create mode 100644 optimum/habana/quantizers/bitsandbytes.py create mode 100644 tests/test_bnb_inference.py create mode 100644 tests/test_bnb_qlora.py diff --git a/optimum/habana/accelerate/utils/modeling.py b/optimum/habana/accelerate/utils/modeling.py new file mode 100644 index 0000000000..2dbbdb951e --- /dev/null +++ b/optimum/habana/accelerate/utils/modeling.py @@ -0,0 +1,52 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Device similarity check compatible with hpu +""" + +import torch + + +def gaudi_check_device_same(first_device, second_device): + """ + Copied from https://github.com/huggingface/accelerate/blob/6b2d968897c91bc3f96274b4679d84e9950ad908/src/accelerate/utils/modeling.py#L50 + difference is addition of HPU device checks + + Args: + first_device (`torch.device`): + First device to check + second_device (`torch.device`): + Second device to check + """ + if first_device.type != second_device.type: + return False + + if first_device.type == "cuda" and first_device.index is None: + # In case the first_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + first_device = torch.device("cuda", index=0) + + elif first_device.type == "hpu" and first_device.index is None: + first_device = torch.device("hpu", index=0) + + if second_device.type == "cuda" and second_device.index is None: + # In case the second_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + second_device = torch.device("cuda", index=0) + + elif second_device.type == "hpu" and second_device.index is None: + second_device = torch.device("hpu", index=0) + + return first_device == second_device diff --git a/optimum/habana/quantizers/bitsandbytes.py b/optimum/habana/quantizers/bitsandbytes.py new file mode 100644 index 0000000000..ee56b55d53 --- /dev/null +++ b/optimum/habana/quantizers/bitsandbytes.py @@ -0,0 +1,265 @@ +from functools import lru_cache +from typing import Any, Dict, List, Optional + +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import Conv1D +from transformers.quantizers.quantizers_utils import get_module_from_name +from transformers.utils import ( + ACCELERATE_MIN_VERSION, + get_available_devices, + is_accelerate_available, + is_bitsandbytes_multi_backend_available, + is_ipex_available, + is_torch_available, + 
logging,
+)
+from transformers.utils.import_utils import _is_package_available
+
+
+if is_torch_available():
+    import torch
+
+_bitsandbytes_available = _is_package_available("bitsandbytes")
+logger = logging.get_logger(__name__)
+
+
+def gaudi_bitsandbytesconfig_post_init(self):
+    r"""
+    Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
+    Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/utils/quantization_config.py#L430
+    Only difference is that the check on the bitsandbytes version is removed.
+    """
+    if not isinstance(self.load_in_4bit, bool):
+        raise TypeError("load_in_4bit must be a boolean")
+
+    if not isinstance(self.load_in_8bit, bool):
+        raise TypeError("load_in_8bit must be a boolean")
+
+    if not isinstance(self.llm_int8_threshold, float):
+        raise TypeError("llm_int8_threshold must be a float")
+
+    if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
+        raise TypeError("llm_int8_skip_modules must be a list of strings")
+    if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
+        raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean")
+
+    if not isinstance(self.llm_int8_has_fp16_weight, bool):
+        raise TypeError("llm_int8_has_fp16_weight must be a boolean")
+
+    if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
+        raise TypeError("bnb_4bit_compute_dtype must be torch.dtype")
+
+    if not isinstance(self.bnb_4bit_quant_type, str):
+        raise TypeError("bnb_4bit_quant_type must be a string")
+
+    if not isinstance(self.bnb_4bit_use_double_quant, bool):
+        raise TypeError("bnb_4bit_use_double_quant must be a boolean")
+
+
+@lru_cache()
+def gaudi_is_bitsandbytes_available():
+    """
+    Copied from https://github.com/huggingface/transformers/blob/5523e38b553ff6c46b04d2376870fcd842feeecc/src/transformers/utils/import_utils.py#L871
+    Only difference is that CUDA-related checks are removed.
+    """
+    if not is_torch_available() or not _bitsandbytes_available:
+        return False
+
+    # Newer versions of `bitsandbytes` can be imported on systems without CUDA.
+    return True
+
+
+def gaudi_validate_bnb_backend_availability(raise_exception=False):
+    """
+    Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
+    Copied from https://github.com/huggingface/transformers/blob/5523e38b553ff6c46b04d2376870fcd842feeecc/src/transformers/integrations/bitsandbytes.py#L545
+    Only difference is that CUDA-related function calls are deleted.
+    """
+    if is_bitsandbytes_multi_backend_available():
+        return _gaudi_validate_bnb_multi_backend_availability(raise_exception)
+
+
+def _gaudi_validate_bnb_multi_backend_availability(raise_exception):
+    """
+    Copied from https://github.com/huggingface/transformers/blob/5523e38b553ff6c46b04d2376870fcd842feeecc/src/transformers/integrations/bitsandbytes.py#L484
+    Only difference is the addition of a check for HPU.
+    """
+    import bitsandbytes as bnb
+
+    bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
+    available_devices = get_available_devices()
+
+    if "hpu" in bnb_supported_devices:
+        logger.debug("Multi-backend validation successful.")
+        return True
+
+    if available_devices == {"cpu"} and not is_ipex_available():
+        from importlib.util import find_spec
+
+        if find_spec("intel_extension_for_pytorch"):
+            logger.warning(
+                "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible."
+            )
+
+        available_devices.discard("cpu")  # Only Intel CPU is supported by BNB at the moment
+
+    if not available_devices.intersection(bnb_supported_devices):
+        if raise_exception:
+            bnb_supported_devices_with_info = set(  # noqa: C401
+                '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)'
+                if device == "cpu"
+                else device
+                for device in bnb_supported_devices
+            )
+            err_msg = (
+                f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. "
+                "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
+            )
+
+            logger.error(err_msg)
+            raise RuntimeError(err_msg)
+
+        logger.warning("No supported devices found for bitsandbytes multi-backend.")
+        return False
+
+    logger.debug("Multi-backend validation successful.")
+    return True
+
+
+def gaudi_validate_environment(self, *args, **kwargs):
+    """
+    Copied from https://github.com/huggingface/transformers/blob/5523e38b553ff6c46b04d2376870fcd842feeecc/src/transformers/quantizers/quantizer_bnb_4bit.py#L68
+    Only difference is the deletion of bitsandbytes version checks.
+    """
+    if not is_accelerate_available():
+        raise ImportError(
+            f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+        )
+    if not gaudi_is_bitsandbytes_available():
+        raise ImportError(
+            "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+        )
+
+    bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
+    gaudi_validate_bnb_backend_availability(raise_exception=True)
+
+    if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
+        raise ValueError(
+            "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
+            " sure the weights are in PyTorch format."
+        )
+
+    device_map = kwargs.get("device_map", None)
+    if (
+        device_map is not None
+        and isinstance(device_map, dict)
+        and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+    ):
+        device_map_without_lm_head = {
+            key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+        }
+        if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
+            pass
+        elif "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values():
+            raise ValueError(
+                "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+                "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+                "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to "
+                "`from_pretrained`. Check "
+                "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+                "for more details. "
+            )
+
+
+def gaudi_create_quantized_param(
+    self,
+    model: "PreTrainedModel",
+    param_value: "torch.Tensor",
+    param_name: str,
+    target_device: "torch.device",
+    state_dict: Dict[str, Any],
+    unexpected_keys: Optional[List[str]] = None,
+):
+    """
+    Copied from https://github.com/huggingface/transformers/blob/62c60a30181a65e1a3a7f19c3055a240a6a21335/src/transformers/quantizers/quantizer_bnb_4bit.py#L138
+    Only difference is the addition of the HPU device.
+    """
+    import bitsandbytes as bnb
+
+    module, tensor_name = get_module_from_name(model, param_name)
+
+    if tensor_name not in module._parameters:
+        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+    old_value = getattr(module, tensor_name)
+
+    if tensor_name == "bias":
+        if param_value is None:
+            new_value = old_value.to(target_device)
+        else:
+            new_value = param_value.to(target_device)
+
+        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
+        module._parameters[tensor_name] = new_value
+        return
+
+    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+        raise ValueError("this function only loads `Linear4bit components`")
+    if (
+        old_value.device == torch.device("meta")
+        and target_device not in ["meta", torch.device("meta")]
+        and param_value is None
+    ):
+        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
+
+    # construct `new_value` for the module._parameters[tensor_name]:
+    if self.pre_quantized:
+        # 4bit loading. Collecting components for restoring quantized weight
+        # This can be expanded to make a universal call for any quantized weight loading
+
+        if not self.is_serializable:
+            raise ValueError(
+                "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
+                "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+            )
+
+        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
+            param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
+        ):
+            raise ValueError(
+                f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
+            )
+
+        quantized_stats = {}
+        for k, v in state_dict.items():
+            if param_name + "." 
in k: + quantized_stats[k] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + param_kwargs = {} + if self.is_bnb_supports_quant_storage_module: + param_kwargs["module"] = module + + new_value = bnb.nn.Params4bit.from_prequantized( + data=param_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=target_device, + **param_kwargs, + ) + else: + if target_device == "hpu": + new_value = param_value.to("hpu") + else: + new_value = param_value.to("cpu") + + # Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization. + # Since weights are saved in the correct "orientation", we skip transposing when loading. + if issubclass(module.source_cls, Conv1D): + new_value = new_value.T + + kwargs = old_value.__dict__ + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 27f4de8820..d807e4527b 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -18,6 +18,14 @@ import transformers.utils.fx from ..accelerate.utils import extract_model_from_parallel +from ..accelerate.utils.modeling import gaudi_check_device_same +from ..quantizers.bitsandbytes import ( + gaudi_bitsandbytesconfig_post_init, + gaudi_create_quantized_param, + gaudi_is_bitsandbytes_available, + gaudi_validate_bnb_backend_availability, + gaudi_validate_environment, +) from .generation import ( GaudiGenerationConfig, GaudiGenerationMixin, @@ -276,6 +284,15 @@ def adapt_transformers_to_gaudi(): accelerate.utils.extract_model_from_parallel = extract_model_from_parallel accelerate.utils.other.extract_model_from_parallel = extract_model_from_parallel accelerate.accelerator.extract_model_from_parallel = extract_model_from_parallel + accelerate.utils.modeling.check_device_same = gaudi_check_device_same + + transformers.utils.quantization_config.BitsAndBytesConfig.post_init = gaudi_bitsandbytesconfig_post_init + transformers.utils.import_utils.is_bitsandbytes_available = gaudi_is_bitsandbytes_available + transformers.utils.is_bitsandbytes_available = gaudi_is_bitsandbytes_available + transformers.quantizers.quantizer_bnb_4bit.is_bitsandbytes_available = gaudi_is_bitsandbytes_available + transformers.integrations.bitsandbytes.validate_bnb_backend_availability = gaudi_validate_bnb_backend_availability + transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer.validate_environment = gaudi_validate_environment + transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer.create_quantized_param = gaudi_create_quantized_param # models that support symbolic tracing should be added to this list models_with_tracing_support = [] diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0afcfbe05a..0cd1a9ccca 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -652,10 +652,18 @@ def pre_attn_forward( else: if past_key_value is None: past_key = torch.zeros( - key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device + key_states.shape, + dtype=self.get_k_proj_weight_dtype() + if self.get_k_proj_weight_dtype() != torch.uint8 + else key_states.dtype, + 
device=key_states.device,
                     )
                     past_value = torch.zeros(
-                        key_states.shape, dtype=self.get_k_proj_weight_dtype(), device=key_states.device
+                        key_states.shape,
+                        dtype=self.get_k_proj_weight_dtype()
+                        if self.get_k_proj_weight_dtype() != torch.uint8
+                        else key_states.dtype,
+                        device=key_states.device,
                     )
                     # Return list instead of tuple
                     past_key_value = [past_key, past_value]
diff --git a/tests/test_bnb_inference.py b/tests/test_bnb_inference.py
new file mode 100644
index 0000000000..9218869669
--- /dev/null
+++ b/tests/test_bnb_inference.py
@@ -0,0 +1,66 @@
+# coding=utf-8
+# Copyright 2022 the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+
+import torch
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+from optimum.habana.transformers import modeling_utils
+
+
+modeling_utils.adapt_transformers_to_gaudi()
+
+assert os.environ.get("GAUDI2_CI", "0") == "1", "Execution is not supported on Gaudi1"
+
+MODEL_ID = "meta-llama/Llama-3.2-1B"
+
+
+def get_model(token: str):
+    nf4_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID, quantization_config=nf4_config, device_map={"": "hpu"}, torch_dtype=torch.bfloat16, token=token.value
+    )
+
+    return model
+
+
+def test_nf4_quantization_inference(token: str):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token.value)
+
+    model = get_model(token)
+
+    generation_config = copy.deepcopy(model.generation_config)
+    generation_config.max_new_tokens = 20
+    generation_config.use_cache = True
+    generation_config.use_flash_attention = True
+
+    model = wrap_in_hpu_graph(model)
+
+    input_text = "Hello my name is"
+    inputs = tokenizer(input_text, return_tensors="pt").to(device="hpu")
+
+    torch.manual_seed(42)
+    outputs = model.generate(**inputs, generation_config=generation_config, hpu_graphs=True, lazy_mode=True)
+    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    assert decoded_output == "Hello my name is Marlene and I am 36 years old. I am a very happy person, I love to"
diff --git a/tests/test_bnb_qlora.py b/tests/test_bnb_qlora.py
new file mode 100644
index 0000000000..ac33a74ee1
--- /dev/null
+++ b/tests/test_bnb_qlora.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2022 the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import subprocess
+
+import pytest
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorForLanguageModeling
+
+from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments
+from optimum.habana.transformers import modeling_utils
+
+
+modeling_utils.adapt_transformers_to_gaudi()
+
+assert os.environ.get("GAUDI2_CI", "0") == "1", "Execution is not supported on Gaudi1"
+try:
+    import sys
+
+    subprocess.check_call([sys.executable, "-m", "pip", "install", "peft==0.12.0"])
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+except subprocess.CalledProcessError:
+    pytest.fail("Failed to install peft==0.12.0")
+
+MODEL_ID = "meta-llama/Llama-3.2-1B"
+
+
+def print_model_size(model):
+    """
+    Prints the model size in GB.
+    """
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters())
+    model_size_GB = model_size / (1024**3)
+    print(f" Model size : {model_size_GB} GB")
+
+
+def print_trainable_parameters(model):
+    """
+    Prints the number of trainable parameters in the model.
+    """
+    trainable_params = 0
+    all_param = 0
+    for _, param in model.named_parameters():
+        all_param += param.numel()
+        if param.requires_grad:
+            trainable_params += param.numel()
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+    )
+
+
+def get_data(tokenizer, dataset_name):
+    dataset = load_dataset(dataset_name)
+    dataset = dataset.shuffle(seed=42)
+    data = dataset.map(lambda example: tokenizer(example["text"]), batched=True)
+    split_data = data["train"].train_test_split(test_size=0.1, seed=42)
+
+    return split_data
+
+
+def get_model(token: str):
+    nf4_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.bfloat16,
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID, quantization_config=nf4_config, device_map={"": "hpu"}, torch_dtype=torch.bfloat16, token=token.value
+    )
+
+    return model
+
+
+def test_nf4_quantization_inference(token: str):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=token.value)
+    # needed for llama tokenizer
+    tokenizer.pad_token = tokenizer.eos_token
+
+    model = get_model(token)
+    model.gradient_checkpointing_enable()
+    print_model_size(model)
+
+    model = prepare_model_for_kbit_training(model)
+
+    config = LoraConfig(
+        r=4,
+        lora_alpha=64,
+        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+    )
+
+    model = get_peft_model(model, config)
+    print_trainable_parameters(model)
+
+    data = get_data(tokenizer, dataset_name="tatsu-lab/alpaca")
+
+    gaudi_config = GaudiConfig(
+        use_fused_adam=True,
+        use_fused_clip_norm=True,
+        use_torch_autocast=True,
+    )
+
+    training_args = GaudiTrainingArguments(
+        evaluation_strategy="steps",
+        per_device_train_batch_size=8,
+        per_device_eval_batch_size=8,
+        gradient_accumulation_steps=2,
+        max_steps=5,
+        eval_steps=3,
+        warmup_steps=3,
+        learning_rate=2e-4,
+        logging_steps=1,
+        output_dir="results",
+        lr_scheduler_type="linear",
+        use_habana=True,
+        use_lazy_mode=True,
+        pipelining_fwd_bwd=True,
+    )
+
+    trainer = GaudiTrainer(
+        model=model,
+        train_dataset=data["train"],
+        eval_dataset=data["test"],
+        gaudi_config=gaudi_config,
+        args=training_args,
+        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+    )
+    model.config.use_cache 
= False # silence the warnings. Please re-enable for inference! + + trainer.train() + eval_loss = trainer.evaluate()["eval_loss"] + + expected_eval_loss = 1.638 + + assert abs(eval_loss - expected_eval_loss) < 5e-2
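
Usage sketch: the snippet below mirrors tests/test_bnb_inference.py and shows the intended flow end to end. `adapt_transformers_to_gaudi()` installs the HPU-aware bitsandbytes hooks patched in modeling_utils.py above, after which a causal LM can be loaded in 4-bit NF4 directly onto HPU. This is a minimal sketch, not part of the diff itself; the model ID, prompt, and generation settings are illustrative.

# Minimal NF4-on-HPU sketch, following tests/test_bnb_inference.py (model ID and prompt are illustrative).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from optimum.habana.transformers import modeling_utils

# Patch transformers/accelerate with the HPU-aware bitsandbytes hooks added in this change.
modeling_utils.adapt_transformers_to_gaudi()

model_id = "meta-llama/Llama-3.2-1B"  # assumption: any causal LM supported by bitsandbytes 4-bit works here

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=nf4_config,
    device_map={"": "hpu"},  # place the quantized weights directly on HPU
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("Hello my name is", return_tensors="pt").to("hpu")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))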