From 41bc1ce4c0e0f8c341fa7e7738b6210519d880d9 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Tue, 10 Sep 2024 18:26:36 -0700
Subject: [PATCH] spinquant in eager mode (#5125)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5125

This PR adds the option to export the model with SpinQuant on GPU (a usage
sketch follows the diff below).

Reviewed By: mergennachin

Differential Revision: D62042861

fbshipit-source-id: 74274fcb3408e5f6b23e0c924272385090da03d2
---
 examples/models/llama2/TARGETS             |   2 +
 examples/models/llama2/export_llama_lib.py | 109 +++++++++++-------
 .../source_transformation/spin_quant.py    |  55 +++++++++
 3 files changed, 124 insertions(+), 42 deletions(-)
 create mode 100644 examples/models/llama2/source_transformation/spin_quant.py

diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index ae3e1e00f9..f1c56a5bda 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -75,6 +75,7 @@ runtime.python_library(
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
+        "source_transformation/spin_quant.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama2",
@@ -85,6 +86,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 611bf16428..dd5822c23f 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import pkg_resources
 
@@ -340,6 +340,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         required=False,
         default="SM8650",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only cuda and native are supported.",
+    )
     return parser
 
 
@@ -411,46 +420,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None
 
-    # source transforms
-    transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
-
-    if args.embedding_quantize:
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
-    if args.expand_rope_table:
-        transforms.append(materialze_broadcast_of_rope_freq_cis)
-
-    if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-
-    if args.use_kv_cache:
-        if args.qnn:
-            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-            from executorch.backends.qualcomm.utils.utils import (
-                convert_linear_to_conv2d,
-            )
-
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-            transforms.append(replace_rms_norm_with_native_rms_norm)
-            if args.optimized_rotation_path:
-                transforms.append(fuse_layer_norms)
-                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-            transforms.append(convert_linear_to_conv2d)
-
-        elif args.coreml or args.mps:
-            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
-            # to get free perf gain.
-            transforms.append(replace_sdpa_with_simple_sdpa)
-            transforms.append(replace_causal_mask)
-
     return (
         _load_llama_model(
             modelname=modelname,
@@ -474,7 +443,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .to_dtype(dtype_override)
-        .source_transform(transforms)
+        .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )
 
 
@@ -763,3 +732,59 @@ def _load_llama_model(
         ),
         args=args,
     )
+
+
+def _get_source_transforms(
+    modelname: str, dtype_override: Optional[DType], args
+) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
+            from executorch.backends.qualcomm.utils.utils import (
+                convert_linear_to_conv2d,
+            )
+
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
+            transforms.append(convert_linear_to_conv2d)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
+
+    return transforms
diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py
new file mode 100644
index 0000000000..7b38312c18
--- /dev/null
+++ b/examples/models/llama2/source_transformation/spin_quant.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to run SpinQuant.
+# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
+
+import torch
+
+import torch.nn.functional as F
+
+from executorch.examples.models.llama2.llama_transformer import FeedForward
+from torch import nn
+
+
+def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
+    """
+    SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 into the feed-forward layer.
+    R3 needs to be injected as well when KV cache quantization is enabled.
+    """
+    try:
+        from fast_hadamard_transform import hadamard_transform
+    except ImportError:
+        raise ImportError(
+            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
+        )
+
+    class FeedForwardCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            w = F.silu(self.w1(x)) * self.w3(x)
+            n = w.shape[-1]
+            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_cuda_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
+    return module
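
Usage sketch (not part of the diff above): the new -sq/--use_spin_quant flag plugs into
the existing export_llama entry point. A minimal invocation, assuming the pre-existing
-c/--checkpoint and -p/--params flags; the paths are placeholders:

    # Export a Llama model with the CUDA SpinQuant transform applied.
    # Requires: pip install fast-hadamard-transform (and a CUDA device).
    python -m examples.models.llama2.export_llama \
        -c /path/to/consolidated.00.pth \
        -p /path/to/params.json \
        --use_spin_quant cuda

Passing --use_spin_quant native currently raises NotImplementedError, per the diff above.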
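Background for the R4 rotation in FeedForwardCustom.forward: hadamard_transform from the
fast-hadamard-transform package computes an unnormalized Walsh-Hadamard transform, and the
division by sqrt(n) makes the rotation orthonormal. A pure-PyTorch reference for CPU
sanity-checking (a sketch only; hadamard_transform_ref is a hypothetical helper, not part
of this patch, and it assumes the transformed dim is a power of two):

    import torch

    def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor:
        # Unnormalized fast Walsh-Hadamard transform over the last dim,
        # equivalent to x @ H_n for the (symmetric) Sylvester Hadamard matrix.
        n = x.shape[-1]
        assert n > 0 and (n & (n - 1)) == 0, "last dim must be a power of two"
        y = x.clone()
        h = 1
        while h < n:
            y = y.reshape(*x.shape[:-1], n // (2 * h), 2, h)
            a, b = y[..., 0, :], y[..., 1, :]
            # Butterfly step: combine elements that are h apart.
            y = torch.cat((a + b, a - b), dim=-1)
            h *= 2
        return y.reshape(x.shape)

    # H_n / sqrt(n) is orthonormal and symmetric, hence involutory: applying the
    # normalized transform twice recovers the input.
    x = torch.randn(4, 64)
    scale = torch.tensor(float(x.shape[-1])).sqrt()
    assert torch.allclose(
        hadamard_transform_ref(hadamard_transform_ref(x) / scale) / scale,
        x,
        atol=1e-5,
    )

This mirrors what the injected forward computes between the SiLU gate and w2; the check
above is only a CPU-side sanity test of the transform's orthonormality.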