HIGGS Quantization Support #34997
base: main
Conversation
cc @MekkCyber
Failed tests look like a problem on the runner's end
Thanks for integrating this new quantization method so fast! I left some comments, and don't forget to also update the documentation so that users know how to use it!
if weight.device.type != "cuda":
    raise ValueError(
        "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
    )
not necessary to put it here. The check on device_map when we initialize the quantizer would be enough.
Removed
else:
    raise NotImplementedError(
        "HIGGS quantization is only supported on GPU. Please use a different quantizer."
    )
let's check if cuda is available in validate_environment instead
Done
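(For readers following along, the resulting check amounts to something like the sketch below, assuming the usual HfQuantizer pattern; the PR's exact method body may differ.)

import torch

def validate_environment(self, *args, **kwargs):
    # Fail early if no GPU is present, instead of erroring per-weight later.
    if not torch.cuda.is_available():
        raise NotImplementedError(
            "HIGGS quantization is only supported on GPU. Please use a different quantizer."
        )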
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
    device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]
could you add a comment on what we are doing here ?
Added comments to this and possible repacking happening afterwards
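(Roughly, the added comments explain the following; the if-guard here is an assumption, only the assignment lines appear in the diff above.)

# FLUTE's Stream-K kernels need a per-device workspace buffer: create it once
# per CUDA device and cache it in flute_workspaces so every HiggsLinear module
# on that device can reuse it.
if module.weight.device not in flute_workspaces:
    flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
        device=module.weight.device
    )
module.workspace = flute_workspaces[module.weight.device]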
self.bits = bits
self.p = p
self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
self.num_sms_packed = 128
Can you add a description of what this is used for ? The user shouldn't have to worry about that ?
Updated the docstring to better reflect what those are
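(A rough idea of what the updated docstring conveys; the wording below is assumed, not quoted from the PR.)

# num_sms_packed (int, optional, defaults to 128):
#     Number of streaming multiprocessors (SMs) of the GPU the quantized codes
#     were packed for. FLUTE's packed-code layout depends on the SM count, so
#     this value is kept to decide whether the codes must be repacked on load.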
def post_init(self):
    r"""
    Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
    """
    return
add in post_init checks for bits and p
Done
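(For context, the added checks presumably look something like the sketch below; the accepted value sets shown are illustrative assumptions, not the ones hard-coded in the PR.)

def post_init(self):
    r"""
    Safety checker that arguments are correct.
    """
    # Illustrative bounds only; the PR defines the actual supported values.
    if self.bits not in (2, 3, 4):
        raise ValueError(f"bits must be one of (2, 3, 4), got {self.bits}")
    if self.p not in (1, 2):
        raise ValueError(f"p must be one of (1, 2), got {self.p}")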
# @require_torch_gpu
# class HiggsConfigTest(unittest.TestCase):
#     def test_to_dict(self):
#         """
#         Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
#         """
#         quantization_config = HiggsConfig()
#         config_to_dict = quantization_config.to_dict()
#
#         for key in config_to_dict:
#             self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
#
#     def test_from_dict(self):
#         """
#         Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
#         """
#         dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"}
#         quantization_config = HiggsConfig.from_dict(dict)
#
#         self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize)
#         self.assertEqual(dict["quant_method"], quantization_config.quant_method)
to remove or uncomment
Uncommented this
@require_accelerate
# @require_read_token
class HiggsTest(unittest.TestCase):
    model_name = "meta-llama/Meta-Llama-3.1-8B"
can we use a smaller model like a tiny llama? This will be better for our CI, thanks!
Sadly, no. FLUTE is only compiled for specific matrix shapes, for now.
TinyLlama is not among those shapes, nor is any model smaller than 8B.
offload_device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": 0,
    "model.layers.3": 0,
    "model.layers.4": 0,
    "model.layers.5": 0,
    "model.layers.6": 0,
    "model.layers.7": 0,
    "model.layers.8": 0,
    "model.layers.9": 0,
    "model.layers.10": 0,
    "model.layers.11": 0,
    "model.layers.12": 0,
    "model.layers.13": 0,
    "model.layers.14": 0,
    "model.layers.15": 0,
    "model.layers.16": "cpu",
    "model.layers.17": "cpu",
    "model.layers.18": "cpu",
    "model.layers.19": "cpu",
    "model.layers.20": "disk",
    "model.layers.21": "disk",
    "model.layers.22": "disk",
    "model.layers.23": "disk",
    "model.layers.24": "disk",
    "model.layers.25": "disk",
    "model.layers.26": "disk",
    "model.layers.27": "disk",
    "model.layers.28": "disk",
    "model.layers.29": "disk",
    "model.layers.30": "disk",
    "model.layers.31": "disk",
    "model.norm": "disk",
    "lm_head": "disk",
}
To remove, or perform some tests with this device_map. I think we shouldn't allow users to pass this kind of device_map, and some check should be added in validate_environment. Check for example the awq integration code.
Removed this from tests, added device_map assertions to validate_environment.
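(The added assertion presumably follows the same pattern as other quantizer integrations such as AWQ; sketch below, not the PR's exact wording.)

device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
    raise ValueError(
        "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
        " This is not supported. Please remove the CPU or disk device from the device_map."
    )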
Co-authored-by: Marc Sun <[email protected]>
Hey @BlackSamorez, thanks for adding this quantization method so quickly! I added some very small nits.
@@ -66,6 +66,10 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git

# Add flute-kernel and fast_hadamard_transform for quantization testing
RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6
The docker image will be deployed on an instance with CUDA 11.8, but on the flute GitHub I noticed you need to specify https://flute-ai.github.io/whl/cu118 in that case.
Thanks, updated.
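(A plausible form of the updated install line; whether the PR passes the cu118 wheel index via -i or --extra-index-url is an assumption.)

# Add flute-kernel and fast_hadamard_transform for quantization testing
RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 --extra-index-url https://flute-ai.github.io/whl/cu118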
class HiggsHfQuantizer(HfQuantizer):
    """
    Quantizer of the HIGGS method. Enables the loading of prequantized models.
    """
just a small nit, I think we should specify that it enables both loading and quantization of models because there are other quantizers that only enable loading
I added "and in-flight quantization of full-precision models" to the docstring.
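(So the docstring now reads roughly:)

class HiggsHfQuantizer(HfQuantizer):
    """
    Quantizer of the HIGGS method. Enables the loading of prequantized models
    and in-flight quantization of full-precision models.
    """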
module.num_sms_packed = torch.nn.Parameter(
    torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
    requires_grad=False,
)
Just for my understanding, why do we need the num_sms_packed?
Codes packing is SM-dependent. We need to remember the SM count of the machine the codes were packed on (num_sms_packed) to be able to check whether we need to repack or not. Moreover, we need num_sms_packed to do the repacking itself.
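(Illustratively, the repack-on-load decision looks something like the sketch below; repack_codes is a placeholder name for whatever FLUTE utility is actually called, not a real API.)

# Compare the SM count the codes were packed for with the current GPU's SM count.
num_sms_current = torch.cuda.get_device_properties(module.weight.device).multi_processor_count
if module.num_sms_packed.item() != num_sms_current:
    # Layouts differ: repack the quantized codes for the current GPU.
    module.weight.data = repack_codes(
        module.weight.data,
        num_sms_packed=module.num_sms_packed.item(),
        num_sms_target=num_sms_current,
    )
    module.num_sms_packed.data.fill_(num_sms_current)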
    num_bits=module.num_bits,
    group_size=256,
    num_sms_packed=module.num_sms_packed.item(),
)
Just a small question, is the group_size a constant?
Yes, there are a few hard-coded constants right now, including the group size. I think I will do a small refactoring to spell them out more explicitly.
module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
    # Add here check for loaded components' dtypes once serialization is implemented
    return True
Do you mean that serialization is not implemented yet? So we can't save a quantized model and load it?
No, serialization is fully functional. This message got copied over from the bnb code I borrowed, and I forgot to remove it.
By the way, bnb implemented serialization quite some time ago as well.
nb_fbgemm_linear = 0
for module in model.modules():
    if isinstance(module, HiggsLinear):
        nb_fbgemm_linear += 1
I think you meant nb_higgs_linear 😉
Sure. Fixed
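(After the rename, the counting loop reads:)

nb_higgs_linear = 0
for module in model.modules():
    if isinstance(module, HiggsLinear):
        nb_higgs_linear += 1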
for m in module_tree:
    parent = parent._modules[m]
return parent
sorry if i'm mistaken, I don't believe we use this function anywhere
Removed the unused function. Thanks!
Co-authored-by: Mohamed Mekkouri <[email protected]>
@SunMarc @MekkCyber thanks for your feedback!
HIGGS 0-Shot Quantization
HIGGS is a new 0-shot quantization algorithm that combines Hadamard preprocessing with MSE-Optimal quantization grids to achieve lower quantization error and SOTA performance. You can find more information in the paper.
Runtime support for HIGGS is implemented through the FLUTE kernel and its library.
This PR adds support for HIGGS+FLUTE into transformers, allowing for low-error 0-shot quantization and fast LLM inference.
Fixes # (issue)
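A hedged usage sketch of the integration (class and argument names taken from the config and tests discussed above; defaults and the exact signature may differ):

from transformers import AutoModelForCausalLM, HiggsConfig

quantization_config = HiggsConfig(bits=4, p=2)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    quantization_config=quantization_config,
    device_map="auto",
)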
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.