-
Notifications
You must be signed in to change notification settings - Fork 27.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
HIGGS Quantization Support #34997
base: main
Are you sure you want to change the base?
HIGGS Quantization Support #34997
Changes from 14 commits
08b347c
14a0c82
1c5b9e7
9f2ef77
0ff58c3
2e9adc6
b6bad71
c2bcf39
a1e7b35
8f1a0a6
120f360
fdb71a5
127c5f0
947a53d
e6ddc41
0de97f1
1f08cb0
ed369be
96023ab
60ce44b
8142443
1d636ac
66ece1d
5a68cd6
1cb9f0c
257c39b
f82d1a3
b747980
398d5b1
0ede69c
ebe6766
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
|
||
from .base import HfQuantizer | ||
from .quantizers_utils import get_module_from_name | ||
|
||
|
||
if TYPE_CHECKING: | ||
from ..modeling_utils import PreTrainedModel | ||
|
||
from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging | ||
from ..utils.quantization_config import QuantizationConfigMixin | ||
|
||
|
||
if is_torch_available(): | ||
import torch | ||
|
||
logger = logging.get_logger(__name__) | ||
|
||
|
||
# Finds the parent of a node module named "name" | ||
def find_parent(model, name): | ||
module_tree = name.split(".")[:-1] | ||
parent = model | ||
for m in module_tree: | ||
parent = parent._modules[m] | ||
return parent | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry if i'm mistaken, I don't believe we use this function anywhere There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed the unused function. Thanks! |
||
|
||
class HiggsHfQuantizer(HfQuantizer):
    """
    Quantizer of the HIGGS method. Enables both the loading of prequantized
    models and the on-the-fly quantization of full-precision models.
    """

    # HIGGS is data-free: weights are quantized directly, no calibration pass.
    requires_calibration = False
    # Parameters are converted one-by-one via `create_quantized_param`.
    requires_parameters_quantization = True
    required_packages = ["flute-kernel", "fast_hadamard_transform"]
    optimum_quantizer = None

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        """Raise ImportError unless every runtime dependency of HIGGS is importable."""
        if not is_accelerate_available():
            raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")

        if not is_flute_available():
            raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel`")

        if not is_hadamard_available():
            raise ImportError(
                "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        """Pick a default dtype when the user did not pass one.

        HIGGS kernels run on GPU in half precision, so default to float16 when
        CUDA is present; otherwise fail early rather than load an unusable model.
        """
        if torch_dtype is None:
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming HIGGS inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            else:
                raise NotImplementedError(
                    "HIGGS quantization is only supported on GPU. Please use a different quantizer."
                )
        return torch_dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
    ):
        """Quantize `param_value` with HIGGS and store the resulting tensors on `model`.

        `quantize_with_higgs` returns a dict of tensors (packed weight, scales,
        auxiliary tables) whose keys must match parameters/buffers declared on
        the target `HiggsLinear` module.
        """
        from ..integrations import quantize_with_higgs

        flute_dict = quantize_with_higgs(
            param_value.to(target_device),
            self.quantization_config.bits,
            self.quantization_config.p,
        )

        # Free the full-precision tensor as soon as possible to cap peak memory.
        del param_value

        module, tensor_name = get_module_from_name(model, param_name)
        for key, value in flute_dict.items():
            if key in module._parameters:
                module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
            elif key in module._buffers:
                module._buffers[key] = torch.nn.Buffer(value)
            else:
                raise ValueError(f"Unexpected key {key} in module {module}")

        if unexpected_keys is not None and param_name in unexpected_keys:
            unexpected_keys.remove(param_name)

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        """Swap `nn.Linear` layers for `HiggsLinear` placeholders before weights load."""
        from ..integrations import replace_with_higgs_linear

        replace_with_higgs_linear(
            model,
            quantization_config=self.quantization_config,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """Attach FLUTE workspaces to every `HiggsLinear` module.

        The FLUTE matmul kernel needs a scratch workspace buffer. Workspaces are
        device-bound, so allocate one per device and share it across all
        `HiggsLinear` modules that live on that device.
        """
        import flute.utils

        from ..integrations import HiggsLinear

        flute_workspaces = {}
        for name, module in model.named_modules():
            if isinstance(module, HiggsLinear):
                if module.weight.device not in flute_workspaces:
                    flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
                        device=module.weight.device
                    )
                module.workspace = flute_workspaces[module.weight.device]

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """Drop HIGGS-generated tensor names from `missing_keys`.

        Packed/auxiliary tensors of `HiggsLinear` modules (everything except the
        original `.weight`/`.bias`) are created at quantization time, so the
        checkpoint legitimately does not contain them.
        """
        from ..integrations import HiggsLinear

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, HiggsLinear):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        # HIGGS-packed weights are frozen (requires_grad=False); no training support.
        return False

    def is_serializable(self, safe_serialization=None):
        # Quantized tensors are plain parameters/buffers, so saving works as-is.
        return True

    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """Return True if `param_name` is a `HiggsLinear.weight` still in full precision.

        An int16 weight means the checkpoint is already HIGGS-packed, in which
        case the parameter is loaded as-is instead of being re-quantized.
        """
        from ..integrations import HiggsLinear

        module, tensor_name = get_module_from_name(model, param_name)
        return isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -102,6 +102,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ | |
_bitsandbytes_available = _is_package_available("bitsandbytes") | ||
_eetq_available = _is_package_available("eetq") | ||
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu") | ||
# FLUTE is pinned to an exact release because the quantized-weight layout it
# consumes is tied to the kernel version. `find_spec` only proves the module is
# importable; the distribution version must come from package metadata, and a
# missing distribution is treated as "not available" rather than an error.
try:
    _flute_available = (
        importlib.util.find_spec("flute") is not None
        and importlib.metadata.version("flute-kernel") == "0.2.6"
    )
except importlib.metadata.PackageNotFoundError:
    _flute_available = False
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's do the version check in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
_galore_torch_available = _is_package_available("galore_torch") | ||
_lomo_available = _is_package_available("lomo_optim") | ||
_grokadamw_available = _is_package_available("grokadamw") | ||
|
@@ -126,6 +132,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ | |
_faiss_available = False | ||
_ftfy_available = _is_package_available("ftfy") | ||
_g2p_en_available = _is_package_available("g2p_en") | ||
_hadamard_available = _is_package_available("fast_hadamard_transform") | ||
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) | ||
_jieba_available = _is_package_available("jieba") | ||
_jinja_available = _is_package_available("jinja2") | ||
|
@@ -330,6 +337,10 @@ def is_torch_deterministic(): | |
return True | ||
|
||
|
||
def is_hadamard_available():
    """Return whether the `fast_hadamard_transform` package is installed."""
    return _hadamard_available
|
||
|
||
def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
    """Return whether `hqq` is installed at version `min_version` or newer."""
    return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
|
||
|
@@ -602,6 +613,10 @@ def is_flax_available(): | |
return _flax_available | ||
|
||
|
||
def is_flute_available():
    """Return whether a supported version of FLUTE (`flute-kernel`) is installed."""
    return _flute_available
|
||
|
||
def is_ftfy_available():
    """Return whether the `ftfy` text-fixing package is installed."""
    return _ftfy_available
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docker image will be deployed on an instance with CUDA 11.8, but on the
flute
GitHub page I noticed you need to specify `https://flute-ai.github.io/whl/cu118`
as the wheel index in that case.

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, updated.