
HIGGS Quantization Support #34997

Open · wants to merge 30 commits into base: main

Conversation

@BlackSamorez (Contributor) commented Nov 28, 2024

HIGGS 0-Shot Quantization

HIGGS is a new 0-shot quantization algorithm that combines Hadamard preprocessing with MSE-Optimal quantization grids to achieve lower quantization error and SOTA performance. You can find more information in the paper.

Runtime support for HIGGS is implemented through FLUTE and its library (flute-kernel).

This PR adds support for HIGGS+FLUTE to transformers, allowing for low-error 0-shot quantization and fast LLM inference.
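
Below is a minimal usage sketch, assuming the HiggsConfig quantization config introduced in this PR; the bits/p values and device_map shown are illustrative, not prescribed defaults:

from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

# Illustrative settings; HiggsConfig exposes bits, p and linear_weights_not_to_quantize
quantization_config = HiggsConfig(bits=4, p=2)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    quantization_config=quantization_config,
    device_map="auto",  # HIGGS/FLUTE kernels run on CUDA devices
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")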

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1 (Member):
cc @SunMarc @MekkCyber

@SunMarc (Member) commented Nov 28, 2024:

cc @MekkCyber

@BlackSamorez (Contributor, Author):
Failed tests look like a problem on the runner's end

@SunMarc (Member) left a comment:

Thanks for integrating this new quantization method so fast! I left some comments; also, don't forget to update the documentation so that users know how to use it!

Comment on lines 320 to 323
if weight.device.type != "cuda":
raise ValueError(
"You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
)
Member:
not necessary to put it here. The check on device_map when we initialize the quantizer would be enough.

Contributor Author:
Removed

src/transformers/integrations/higgs.py (outdated; resolved)
src/transformers/quantizers/quantizer_higgs.py (outdated; resolved)
Comment on lines 75 to 78
else:
raise NotImplementedError(
"HIGGS quantization is only supported on GPU. Please use a different quantizer."
)
Member:
let's check if cuda is available in validate_environment instead

Contributor Author:
Done
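
For reference, a rough sketch of the kind of early check being suggested (illustrative only, not the exact code in this PR):

import torch

def validate_environment(self, *args, **kwargs):
    # HIGGS/FLUTE kernels only run on CUDA GPUs, so fail early if none is available
    if not torch.cuda.is_available():
        raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")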

Comment on lines +137 to +140
flute_workspaces[module.weight.device] = flute.utils.make_workspace_streamk(
device=module.weight.device
)
module.workspace = flute_workspaces[module.weight.device]
Member:

Could you add a comment on what we are doing here?

Contributor Author:

Added comments to this and to the possible repacking that happens afterwards.

self.bits = bits
self.p = p
self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
self.num_sms_packed = 128
Member:

Can you add a description of what this is used for? The user shouldn't have to worry about that.

Contributor Author:
Updated the docstring to better reflect what those are

Comment on lines 1267 to 1272
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
return

Member:

Add post_init checks for bits and p.

Contributor Author:
Done
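
A possible shape for those checks (illustrative; the accepted values for bits and p below are assumptions, not taken from this PR):

def post_init(self):
    r"""
    Safety checker that arguments are correct.
    """
    # Placeholder value sets; the real checks should mirror what the FLUTE kernels support
    if self.bits not in (2, 3, 4):
        raise ValueError(f"bits must be one of (2, 3, 4), got {self.bits}")
    if self.p not in (1, 2):
        raise ValueError(f"p must be one of (1, 2), got {self.p}")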

Comment on lines 39 to 60
# @require_torch_gpu
# class HiggsConfigTest(unittest.TestCase):
# def test_to_dict(self):
# """
# Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
# """
# quantization_config = HiggsConfig()
# config_to_dict = quantization_config.to_dict()

# for key in config_to_dict:
# self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

# def test_from_dict(self):
# """
# Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
# """
# dict = {"linear_weights_not_to_quantize": ["embed_tokens.weight", "lm_head.weight"], "quant_method": "higgs"}
# quantization_config = HiggsConfig.from_dict(dict)

# self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize)
# self.assertEqual(dict["quant_method"], quantization_config.quant_method)

Member:
to remove or uncomment

Contributor Author:
Uncommented this

@require_accelerate
# @require_read_token
class HiggsTest(unittest.TestCase):
model_name = "meta-llama/Meta-Llama-3.1-8B"
Member:

Can we use a smaller model, like a tiny llama? This would be better for our CI, thanks!

Contributor Author:
Sadly, no. FLUTE is only compiled for specific matrix shapes for now.
TinyLlama's shapes are not among them, nor are those of any model smaller than 8B.

Comment on lines 77 to 113
offload_device_map = {
"model.embed_tokens": 0,
"model.layers.0": 0,
"model.layers.1": 0,
"model.layers.2": 0,
"model.layers.3": 0,
"model.layers.4": 0,
"model.layers.5": 0,
"model.layers.6": 0,
"model.layers.7": 0,
"model.layers.8": 0,
"model.layers.9": 0,
"model.layers.10": 0,
"model.layers.11": 0,
"model.layers.12": 0,
"model.layers.13": 0,
"model.layers.14": 0,
"model.layers.15": 0,
"model.layers.16": "cpu",
"model.layers.17": "cpu",
"model.layers.18": "cpu",
"model.layers.19": "cpu",
"model.layers.20": "disk",
"model.layers.21": "disk",
"model.layers.22": "disk",
"model.layers.23": "disk",
"model.layers.24": "disk",
"model.layers.25": "disk",
"model.layers.26": "disk",
"model.layers.27": "disk",
"model.layers.28": "disk",
"model.layers.29": "disk",
"model.layers.30": "disk",
"model.layers.31": "disk",
"model.norm": "disk",
"lm_head": "disk",
}
Member:
To remove, or to perform some tests with this device_map. I think we shouldn't allow users to pass this kind of device_map, and a check should be added in validate_environment. See, for example, the awq integration code.

Contributor Author:
Removed this from the tests and added device_map assertions to validate_environment.
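
A sketch of the kind of device_map assertion described here, loosely modeled on the awq integration (illustrative, not the exact code added):

def validate_environment(self, *args, device_map=None, **kwargs):
    if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
        raise ValueError(
            "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
            " This is not supported; please remove the CPU and disk entries from the device_map."
        )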

@MekkCyber (Contributor):

Hey @BlackSamorez, thanks for adding this quantization method so quickly! I added some very small nits.

@@ -66,6 +66,10 @@ RUN python3 -m pip install --no-cache-dir optimum-quanto
# Add eetq for quantization testing
RUN python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git

# Add flute-kernel and fast_hadamard_transform for quantization testing
RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6
Contributor:
The docker image will be deployed on an instance with CUDA 11.8, but on the FLUTE GitHub I noticed that you need to specify https://flute-ai.github.io/whl/cu118 in that case.

Contributor Author:
Thanks, updated.
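
For instance, the install line could end up looking something like this, assuming the cu118 wheel index above is passed via pip's index option (a sketch, not the exact diff):

# Add flute-kernel and fast_hadamard_transform for quantization testing (CUDA 11.8 wheels)
RUN python3 -m pip install --no-cache-dir flute-kernel==0.2.6 -i https://flute-ai.github.io/whl/cu118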

class HiggsHfQuantizer(HfQuantizer):
"""
Quantizer of the HIGGS method. Enables the loading of prequantized models.
"""
Contributor:
Just a small nit: I think we should specify that it enables both loading and quantization of models, because there are other quantizers that only enable loading.

Contributor Author:
I added "and in-flight quantization of full-precision models".

src/transformers/quantizers/quantizer_higgs.py (outdated; resolved)
module.num_sms_packed = torch.nn.Parameter(
torch.tensor(get_num_sms_from_device(target_device), device=target_device, dtype=torch.int32),
requires_grad=False,
)
Contributor:
Just for my understanding, why do we need num_sms_packed?

Contributor Author:
Code packing is SM-dependent. We need to remember the SM count of the machine on which the codes were packed (num_sms_packed) to be able to check whether we need to repack. Moreover, we need num_sms_packed to do the repacking itself.
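
A tiny sketch of that check in plain PyTorch (illustrative; the FLUTE-specific repacking call itself is not shown):

import torch

# Compare the current GPU's SM count with the count the codes were packed for
current_sms = torch.cuda.get_device_properties(module.weight.device).multi_processor_count
needs_repack = current_sms != module.num_sms_packed.item()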

num_bits=module.num_bits,
group_size=256,
num_sms_packed=module.num_sms_packed.item(),
)
Contributor:
Just a small question: is group_size a constant?

Contributor Author:
Yes, there are a few hard-coded constants right now, including the group size. I think I will do a small refactoring to spell them out more explicitly.
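
For example, the hard-coded values could be pulled into named module-level constants (a sketch; only the group size of 256 comes from the snippet above, the constant name is made up):

# Fixed group size currently assumed by the FLUTE kernels used for HIGGS
HIGGS_GROUP_SIZE = 256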

module, tensor_name = get_module_from_name(model, param_name)
if isinstance(module, HiggsLinear) and tensor_name == "weight" and param_value.dtype != torch.int16:
# Add here check for loaded components' dtypes once serialization is implemented
return True
Contributor:
Do you mean that serialization is not implemented yet, so we can't save a quantized model and load it?

Contributor Author:
No, serialization is fully functional. This message got copied over with the bnb code I borrowed, and I forgot to remove it.
By the way, bnb implemented serialization quite some time ago as well.

nb_fbgemm_linear = 0
for module in model.modules():
if isinstance(module, HiggsLinear):
nb_fbgemm_linear += 1
Contributor:
I think you meant nb_higgs_linear 😉

Contributor Author:
Sure. Fixed

for m in module_tree:
parent = parent._modules[m]
return parent

Contributor:
Sorry if I'm mistaken, but I don't believe we use this function anywhere.

Contributor Author:
Removed the unused function. Thanks!

@BlackSamorez (Contributor, Author):
@SunMarc @MekkCyber thanks for your feedback!
I think I addressed all of your concerns.
