From 341545c0c11164d0381ac11871762cf0481ab60c Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Tue, 24 Sep 2024 23:03:52 -0700
Subject: [PATCH] add option to quantize output layer perchannel for SpinQuant (#5614)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5614

Add an option to quantize the output layer with int8 per-channel quantization.

Reviewed By: mergennachin, iseeyuan

Differential Revision: D62787491

fbshipit-source-id: cc86a9105966dddbdfb26f77c62a6e0f9c01d24c
---
 examples/models/llama2/export_llama_lib.py    |  4 +-
 examples/models/llama2/model.py               | 17 ++++-
 .../source_transformation/spin_quant.py       | 65 ++++++++++++++++---
 .../llama2/tests/test_spinquant_transforms.py | 54 ++++++++++++++-
 4 files changed, 126 insertions(+), 14 deletions(-)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 6bf019de23..b0c4c4a23e 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -374,8 +374,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--spin_qmode",
         type=str,
         default=None,
-        choices=["8da4w"],
-        help="Quantization mode for SpinQuant. Only support 8da4w right now.",
+        choices=["8da4w", "8da4w_output_8da8w"],
+        help="Quantization mode for SpinQuant. Only support 8da4w and 8da4w_output_8da8w right now.",
     )
 
     parser.add_argument(
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 08effca2eb..8b0d7ceb3b 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -192,6 +192,10 @@ def __init__(self, **kwargs):
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
             assert hasattr(self.args, "spin_qmode"), "spin_qmode must be specified"
+            assert self.args.spin_qmode in [
+                "8da4w",
+                "8da4w_output_8da8w",
+            ], f"Quantization mode {self.args.spin_qmode} is not compatible with SpinQuant."
             assert hasattr(
                 self.args, "spin_group_size"
             ), "spin_group_size must be specified"
@@ -209,11 +213,22 @@ def __init__(self, **kwargs):
                 "bf16": torch.bfloat16,
             }
 
+            # Transform the output layer first if needed.
+            if self.args.spin_qmode == "8da4w_output_8da8w":
+                from .source_transformation.spin_quant import (
+                    transform_output_linear_for_spinquant,
+                )
+
+                self.model_ = transform_output_linear_for_spinquant(
+                    module=self.model_,
+                    checkpoint=checkpoint,
+                    dtype=mapping[self.args.dtype_override],
+                )
+
             self.model_ = transform_linear_for_spinquant(
                 self.model_,
                 checkpoint,
                 self.args.spin_group_size,
-                self.args.spin_qmode,
                 mapping[self.args.dtype_override],
             )
 
diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py
index b6107492d2..75f01fd6b3 100644
--- a/examples/models/llama2/source_transformation/spin_quant.py
+++ b/examples/models/llama2/source_transformation/spin_quant.py
@@ -20,7 +20,7 @@
 from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
-from .quantize import QuantizedGroupEmbedding
+from .quantize import Int8DynActInt8WeightLinear, QuantizedGroupEmbedding
 
 
 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -129,20 +129,16 @@ def transform_linear_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
     group_size: int,
-    quantization_mode: str,
     dtype: torch.dtype,
 ) -> torch.nn.Module:
     """
     Transform the model to be able to load SpinQuant checkpoints that
-    are quantized with the given group size and quantization mode.
+    are quantized with the given group size and quantization mode for
+    linear layers.
     """
     if group_size not in [32, 64, 128, 256]:
         raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
 
-    if quantization_mode not in ["8da4w"]:
-        raise ValueError(
-            f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
-        )
     _replace_linear_with_linear_8da4w_for_spin_quant(
         module,
         checkpoint,
@@ -153,6 +149,53 @@
     return module
 
 
+def _replace_output_linear_with_linear_int8_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        scales_key = f"{cur_fqn}.scale"
+        if (
+            isinstance(child, nn.Linear)
+            and scales_key in checkpoint
+            and "output" in cur_fqn
+        ):
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == dtype
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt8WeightLinear(
+            device=child.weight.device,
+            in_features=child.in_features,
+            out_features=child.out_features,
+            precision=dtype,
+            bias=False,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_output_linear_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    dtype: torch.dtype,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load SpinQuant checkpoints that
+    have the output layer quantized per-channel.
+    """
+    _replace_output_linear_with_linear_int8_for_spinquant(
+        module,
+        checkpoint,
+        dtype,
+    )
+    return module
+
+
 def _replace_embedding_with_quantized_group_embedding_for_spinquant(
     module: torch.nn.Module,
     checkpoint: Any,
@@ -233,8 +276,10 @@ def sanitize_checkpoint_from_spinquant(
         module_name = new_key[0 : new_key.rfind(".")]
         sub_module = module.get_submodule(module_name)
         assert sub_module is not None
-        assert isinstance(sub_module, Int8DynActInt4WeightLinear) or isinstance(
-            sub_module, QuantizedGroupEmbedding
+        assert (
+            isinstance(sub_module, Int8DynActInt4WeightLinear)
+            or isinstance(sub_module, QuantizedGroupEmbedding)
+            or isinstance(sub_module, Int8DynActInt8WeightLinear)
         )
         # Checkpoints with SpinQuant could come with two formats for scales:
         # 1. scales is grouped by group size
@@ -245,6 +290,8 @@
             checkpoint[new_key] = (
                 old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
             )
+        elif isinstance(sub_module, Int8DynActInt8WeightLinear):
+            checkpoint[new_key] = old_val[:, 0]
         elif isinstance(sub_module, QuantizedGroupEmbedding):
             if (
                 embedding_group_size is None or embedding_group_size == 0
diff --git a/examples/models/llama2/tests/test_spinquant_transforms.py b/examples/models/llama2/tests/test_spinquant_transforms.py
index 745bd6f46a..f7d9a2627e 100644
--- a/examples/models/llama2/tests/test_spinquant_transforms.py
+++ b/examples/models/llama2/tests/test_spinquant_transforms.py
@@ -8,10 +8,14 @@
 import torch
 
 from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama2.source_transformation.quantize import (
+    dynamically_quantize_per_channel,
+)
 from executorch.examples.models.llama2.source_transformation.spin_quant import (
     sanitize_checkpoint_from_spinquant,
     transform_embedding_for_spinquant,
     transform_linear_for_spinquant,
+    transform_output_linear_for_spinquant,
 )
 from torchao.quantization.utils import group_quantize_tensor_symmetric
 
@@ -51,8 +55,7 @@ def test_transform_linear_for_spinquant(self):
         n_bit = 4
         scales_precision = torch.float32
         for fqn, mod in model.named_modules():
-            # Quantize everything except the last layer
-            if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
+            if isinstance(mod, torch.nn.Linear):
                 weight = mod.weight.data
                 (
                     weight_int8,
@@ -92,6 +95,53 @@
             # have to iterate over the keys.
             self.assertTrue(torch.allclose(new_checkpoint[k], v))
 
+    def test_transform_output_linear_for_spinquant(self):
+        # Step 1: Create llama class with dummy weights
+        model = self._prepare_dummy_model()
+        checkpoint = model.state_dict()
+
+        # Step 2:
+        # Do per-channel quantization and amend the checkpoints with
+        # int8 weight and fp32 scales
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear) and fqn == "output":
+                weight = mod.weight.data
+                weight_int8, scales, _ = dynamically_quantize_per_channel(
+                    weight,
+                    quant_min=-128,
+                    quant_max=127,
+                    target_dtype=torch.int8,
+                    scales_dtype=torch.float32,
+                )
+                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
+                checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+
+        # Step 3:
+        # Transform the model so that it is compatible with the new checkpoint
+        transform_output_linear_for_spinquant(
+            model,
+            checkpoint,
+            torch.float32,
+        )
+        sanitize_checkpoint_from_spinquant(
+            model,
+            checkpoint,
+            -1,
+        )
+
+        model.load_state_dict(
+            checkpoint,
+            strict=False,
+            assign=True,
+        )
+
+        new_checkpoint = model.state_dict()
+
+        for k, v in checkpoint.items():
+            # The new_checkpoint contains zeros so
+            # have to iterate over the keys.
+            self.assertTrue(torch.allclose(new_checkpoint[k], v))
+
     def test_transform_embedding_for_spinquant(self):
         # Step 1: Create llama class with dummy weights