HIGGS Quantization Support #34997

Open

wants to merge 31 commits into main

Changes from 14 commits

Commits (31)
08b347c
higgs init
BlackSamorez Nov 27, 2024
14a0c82
working with crunches
BlackSamorez Nov 27, 2024
1c5b9e7
per-model workspaces
BlackSamorez Nov 28, 2024
9f2ef77
style
BlackSamorez Nov 28, 2024
0ff58c3
style 2
BlackSamorez Nov 28, 2024
2e9adc6
Merge branch 'huggingface:main' into main
BlackSamorez Nov 28, 2024
b6bad71
tests and style
BlackSamorez Nov 28, 2024
c2bcf39
higgs tests passing
BlackSamorez Nov 28, 2024
a1e7b35
protecting torch import
BlackSamorez Nov 28, 2024
8f1a0a6
removed torch.Tensor type annotations
BlackSamorez Nov 28, 2024
120f360
torch.nn.Module inheritance fix maybe
BlackSamorez Nov 28, 2024
fdb71a5
hide inputs inside quantizer calls
BlackSamorez Nov 28, 2024
127c5f0
style structure something
BlackSamorez Nov 28, 2024
947a53d
Merge branch 'main' into main
BlackSamorez Nov 28, 2024
e6ddc41
Merge branch 'main' into main
BlackSamorez Nov 28, 2024
0de97f1
Update src/transformers/quantizers/quantizer_higgs.py
BlackSamorez Nov 28, 2024
1f08cb0
reworked num_sms
BlackSamorez Nov 28, 2024
ed369be
Merge branch 'main' of github.com:BlackSamorez/transformers
BlackSamorez Nov 28, 2024
96023ab
Update src/transformers/integrations/higgs.py
BlackSamorez Nov 28, 2024
60ce44b
revamped device checks
BlackSamorez Nov 29, 2024
8142443
docstring upd
BlackSamorez Nov 29, 2024
1d636ac
Update src/transformers/quantizers/quantizer_higgs.py
BlackSamorez Nov 29, 2024
66ece1d
edited tests and device map assertions
BlackSamorez Nov 29, 2024
5a68cd6
Merge branch 'main' of github.com:BlackSamorez/transformers
BlackSamorez Nov 29, 2024
1cb9f0c
minor edits
BlackSamorez Nov 29, 2024
257c39b
updated flute cuda version in docker
BlackSamorez Nov 29, 2024
f82d1a3
Added p=1 and 2,3bit HIGGS
BlackSamorez Nov 29, 2024
b747980
flute version check update
BlackSamorez Nov 29, 2024
398d5b1
incorporated `modules_to_not_convert`
BlackSamorez Nov 29, 2024
0ede69c
less hardcoding
BlackSamorez Nov 29, 2024
ebe6766
Merge branch 'main' into main
BlackSamorez Nov 29, 2024
4 changes: 4 additions & 0 deletions docker/transformers-quantization-latest-gpu/Dockerfile
@@ -64,6 +64,10 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git

# Add flute-kernel and fast_hadamard_transform for quantization testing
RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6
Contributor:
The docker image will be deployed on an instance with CUDA 11.8, but on the FLUTE GitHub I noticed you need to specify https://flute-ai.github.io/whl/cu118 in that case.

Contributor Author:
Thanks, updated.

RUN python3 -m pip install --no-cache-dir fast_hadamard_transform==1.0.4.post1

# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
4 changes: 4 additions & 0 deletions docs/source/en/main_classes/quantization.md
@@ -53,6 +53,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

[[autodoc]] quantizers.base.HfQuantizer

## HiggsConfig

[[autodoc]] HiggsConfig

## HqqConfig

[[autodoc]] HqqConfig
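For orientation, a hedged sketch of how the new config might be used to quantize a model in flight (`bits` mirrors the `quantization_config.bits` attribute read by the quantizer later in this diff; the model id is a placeholder and defaults may differ):

```python
from transformers import AutoModelForCausalLM, HiggsConfig

# Quantize a full-precision checkpoint with HIGGS while loading it onto GPU.
quantization_config = HiggsConfig(bits=4)  # assumed argument name; see HiggsHfQuantizer below
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    quantization_config=quantization_config,
    device_map="auto",
)
```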
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -982,6 +982,7 @@
"EetqConfig",
"FbgemmFp8Config",
"GPTQConfig",
"HiggsConfig",
"HqqConfig",
"QuantoConfig",
"TorchAoConfig",
@@ -5926,6 +5927,7 @@
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
QuantoConfig,
TorchAoConfig,
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -63,6 +63,7 @@
"load_dequant_gguf_tensor",
"load_gguf",
],
"higgs": ["HiggsLinear", "quantize_with_higgs", "replace_with_higgs_linear"],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
@@ -165,6 +166,7 @@
load_dequant_gguf_tensor,
load_gguf,
)
from .higgs import HiggsLinear, quantize_with_higgs, replace_with_higgs_linear
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
506 changes: 506 additions & 0 deletions src/transformers/integrations/higgs.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -24,6 +24,7 @@
EetqConfig,
FbgemmFp8Config,
GPTQConfig,
HiggsConfig,
HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
@@ -39,6 +40,7 @@
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_fbgemm_fp8 import FbgemmFp8HfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_higgs import HiggsHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
from .quantizer_torchao import TorchAoHfQuantizer
@@ -52,6 +54,7 @@
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"higgs": HiggsHfQuantizer,
"hqq": HqqHfQuantizer,
"compressed-tensors": CompressedTensorsHfQuantizer,
"fbgemm_fp8": FbgemmFp8HfQuantizer,
@@ -70,6 +73,7 @@
"hqq": HqqConfig,
"compressed-tensors": CompressedTensorsConfig,
"fbgemm_fp8": FbgemmFp8Config,
"higgs": HiggsConfig,
"torchao": TorchAoConfig,
"bitnet": BitNetConfig,
}
179 changes: 179 additions & 0 deletions src/transformers/quantizers/quantizer_higgs.py
@@ -0,0 +1,179 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
import torch

logger = logging.get_logger(__name__)


# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent

Contributor:
Sorry if I'm mistaken, but I don't believe we use this function anywhere.

Contributor Author:
Removed the unused function. Thanks!


class HiggsHfQuantizer(HfQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models.
"""
Contributor:
Just a small nit: I think we should specify that it enables both loading and quantization of models, because there are other quantizers that only enable loading.

Contributor Author:
I added "and in-flight quantization of full-precision models."


requires_calibration = False
requires_parameters_quantization = True
required_packages = ["flute-kernel", "fast_hadamard_transform"]
optimum_quantizer = None

def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
self.quantization_config = quantization_config

def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")

if not is_flute_available():
raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel`")

if not is_hadamard_available():
raise ImportError(
"Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
)

def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
if torch.cuda.is_available():
torch_dtype = torch.float16
logger.info(
"CUDA available. Assuming HIGGS inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
)
else:
raise NotImplementedError(
"HIGGS quantization is only supported on GPU. Please use a different quantizer."
)
Member:
Let's check if CUDA is available in validate_environment instead.

Contributor Author:
Done

return torch_dtype
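Per the exchange above, a hedged sketch of what moving the CUDA check into `validate_environment` could look like (a sketch only; the PR's final code may differ):

```python
def validate_environment(self, *args, **kwargs):
    if not torch.cuda.is_available():
        raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")

    if not is_accelerate_available():
        raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")

    if not is_flute_available():
        raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel`")

    if not is_hadamard_available():
        raise ImportError(
            "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
        )
```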

def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
):
from ..integrations import quantize_with_higgs

"""
Quantizes weights into weight and weight_scale
"""
flute_dict = quantize_with_higgs(
param_value.to(target_device),
self.quantization_config.bits,
self.quantization_config.p,
)

del param_value

module, tensor_name = get_module_from_name(model, param_name)
for key, value in flute_dict.items():
if key in module._parameters:
module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
elif key in module._buffers:
module._buffers[key] = torch.nn.Buffer(value)
else:
raise ValueError(f"Unexpected key {key} in module {module}")

if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)

def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
**kwargs,
):
from ..integrations import replace_with_higgs_linear

replace_with_higgs_linear(
model,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config

def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
import flute.utils

from ..integrations import HiggsLinear

flute_workspaces = {}
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
if module.weight.device not in flute_workspaces:
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]
Comment on lines +155 to +158

Member:
Could you add a comment on what we are doing here?

Contributor Author:
Added comments to this and to the possible repacking that happens afterwards.
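A hedged paraphrase of what those comments likely explain, restating the loop above (the PR's final wording may differ):

```python
# FLUTE matmuls need a per-device scratch workspace. Allocating one per HiggsLinear
# would waste GPU memory, so a single workspace is created per device and shared by
# every HiggsLinear module that lives on that device.
flute_workspaces = {}
for _, module in model.named_modules():
    if isinstance(module, HiggsLinear):
        device = module.weight.device
        if device not in flute_workspaces:
            flute_workspaces[device] = flute.utils.make_workspace_streamk(device=device)
        module.workspace = flute_workspaces[device]
```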


def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import HiggsLinear

not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, HiggsLinear):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]

@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False

def is_serializable(self, safe_serialization=None):
return True

def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
from ..integrations import HiggsLinear

module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
# Add here check for loaded components' dtypes once serialization is implemented
return True
Contributor:
Do you mean that serialization is not implemented yet? So we can't save a quantized model and load it?

Contributor Author:
No, serialization is fully functional. This message got copied from the bnb code I borrowed, and I forgot to remove it. By the way, bnb implemented serialization quite some time ago as well.

else:
return False
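Since the author confirms serialization works, a hedged round-trip sketch using the standard save/load API (paths are placeholders; `model` is assumed to be a HIGGS-quantized model as in the earlier example):

```python
from transformers import AutoModelForCausalLM

model.save_pretrained("llama-higgs-4bit")  # placeholder directory
reloaded = AutoModelForCausalLM.from_pretrained("llama-higgs-4bit", device_map="auto")
```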
9 changes: 9 additions & 0 deletions src/transformers/testing_utils.py
@@ -75,12 +75,14 @@
is_fbgemm_gpu_available,
is_flash_attn_2_available,
is_flax_available,
is_flute_available,
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_hadamard_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
@@ -1227,6 +1229,13 @@ def require_fbgemm_gpu(test_case):
return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)


def require_flute_hadamard(test_case):
"""
Decorator marking a test that requires flute-kernel and fast_hadamard_transform
"""
return unittest.skipUnless(is_flute_available() and is_hadamard_available(), "test requires flute-kernel and fast_hadamard_transform")(test_case)
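A hypothetical usage of the new decorator in a test module (illustration only; the class and test names are made up):

```python
import unittest

from transformers.testing_utils import require_flute_hadamard, require_torch_gpu


@require_flute_hadamard
@require_torch_gpu
class HiggsQuantizationTest(unittest.TestCase):
    def test_loads_quantized_model(self):
        ...  # would exercise a HIGGS-quantized model here
```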


def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer
2 changes: 2 additions & 0 deletions src/transformers/utils/__init__.py
@@ -138,12 +138,14 @@
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
is_flute_available,
is_fsdp_available,
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_gguf_available,
is_grokadamw_available,
is_hadamard_available,
is_hqq_available,
is_in_notebook,
is_ipex_available,
15 changes: 15 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -102,6 +102,12 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_bitsandbytes_available = _is_package_available("bitsandbytes")
_eetq_available = _is_package_available("eetq")
_fbgemm_gpu_available = _is_package_available("fbgemm_gpu")
try:
_flute_available = package_exists = (
importlib.util.find_spec("flute") is not None and importlib.metadata.version("flute-kernel") == "0.2.6"
)
except importlib.metadata.PackageNotFoundError:
_flute_available = False
Member:
Let's do the version check in is_flute_available, and you shouldn't pin the version like that. We should allow a minimum version.

Contributor Author:
Done
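Following the reviewer's suggestion, a hedged sketch of a minimum-version gate in `is_flute_available` (the `_flute_version` name is an assumption, mirroring the `is_hqq_available` pattern further down; the PR's final code may differ):

```python
from packaging import version

FLUTE_MIN_VERSION = "0.2.6"  # assumed floor rather than an exact pin

def is_flute_available(min_version: str = FLUTE_MIN_VERSION):
    # _flute_version would need to be captured when probing the package,
    # e.g. via importlib.metadata.version("flute-kernel").
    return _flute_available and version.parse(_flute_version) >= version.parse(min_version)
```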

_galore_torch_available = _is_package_available("galore_torch")
_lomo_available = _is_package_available("lomo_optim")
_grokadamw_available = _is_package_available("grokadamw")
@@ -126,6 +132,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
_faiss_available = False
_ftfy_available = _is_package_available("ftfy")
_g2p_en_available = _is_package_available("g2p_en")
_hadamard_available = _is_package_available("fast_hadamard_transform")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
@@ -330,6 +337,10 @@ def is_torch_deterministic():
return True


def is_hadamard_available():
return _hadamard_available


def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)

@@ -602,6 +613,10 @@ def is_flax_available():
return _flax_available


def is_flute_available():
return _flute_available


def is_ftfy_available():
return _ftfy_available
