spinquant in eager mode (pytorch#5125)
Summary:
Pull Request resolved: pytorch#5125

This PR adds the option to export the model with SpinQuant applied, running its fast Hadamard transform on GPU (CUDA).
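
A rough usage sketch (the export_llama entry point and the checkpoint/params flags are assumptions here; only --use_spin_quant / -sq is defined in this diff):

    python -m examples.models.llama2.export_llama \
        --checkpoint /path/to/consolidated.00.pth \
        --params /path/to/params.json \
        --use_spin_quant cuda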

Reviewed By: mergennachin

Differential Revision: D62042861

fbshipit-source-id: 74274fcb3408e5f6b23e0c924272385090da03d2
Lunwen He authored and facebook-github-bot committed Sep 11, 2024
1 parent 69aed24 commit 41bc1ce
Showing 3 changed files with 124 additions and 42 deletions.
2 changes: 2 additions & 0 deletions examples/models/llama2/TARGETS
@@ -75,6 +75,7 @@ runtime.python_library(
"source_transformation/rms_norm.py",
"source_transformation/rope.py",
"source_transformation/sdpa.py",
"source_transformation/spin_quant.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama2",
@@ -85,6 +86,7 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models:models",
109 changes: 67 additions & 42 deletions examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import pkg_resources

@@ -340,6 +340,15 @@ def build_args_parser() -> argparse.ArgumentParser:
required=False,
default="SM8650",
)

parser.add_argument(
"-sq",
"--use_spin_quant",
type=str,
default=None,
choices=["cuda", "native"],
help="Use SpinQuant for better quantization performance. Only support cuda and native.",
)
return parser


@@ -411,46 +420,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
else:
dtype_override = None

# source transforms
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)

if args.embedding_quantize:
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)

if args.use_kv_cache:
if args.qnn:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import (
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
        # Currently qnn/coreml/mps don't support the sdpa op; use the simpler
        # decomposition to get a free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

return (
_load_llama_model(
modelname=modelname,
@@ -474,7 +443,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
)
.set_output_dir(output_dir_path)
.to_dtype(dtype_override)
.source_transform(transforms)
.source_transform(_get_source_transforms(modelname, dtype_override, args))
)


@@ -763,3 +732,59 @@ def _load_llama_model(
),
args=args,
)


def _get_source_transforms(
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)

if args.embedding_quantize:
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)

if args.use_kv_cache:
if args.qnn:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import (
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
        # Currently qnn/coreml/mps don't support the sdpa op; use the simpler
        # decomposition to get a free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)

elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

return transforms
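
Each entry in the returned list is a Callable[[torch.nn.Module], torch.nn.Module], and LLMEdgeManager.source_transform applies them in order. A minimal sketch of that contract (a hypothetical standalone helper, not the actual LLMEdgeManager code):

    import torch
    from typing import Callable, List

    def apply_source_transforms(
        model: torch.nn.Module,
        transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
    ) -> torch.nn.Module:
        # Each transform takes the module and returns the (possibly rewritten) module.
        for transform in transforms:
            model = transform(model)
        return model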
55 changes: 55 additions & 0 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so it can run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

import torch

import torch.nn.functional as F

from executorch.examples.models.llama2.llama_transformer import FeedForward
from torch import nn


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
"""
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 into
    the feed-forward layer. R3 also needs to be injected when KV cache quantization
    is enabled.
"""
try:
from fast_hadamard_transform import hadamard_transform
except ImportError:
raise ImportError(
"Please install fast-hadamard-transform: pip install fast-hadamard-transform"
)

class FeedForwardCustom(nn.Module):
def __init__(self, w1, w2, w3):
super().__init__()
self.w1 = w1
self.w2 = w2
self.w3 = w3

def forward(self, x):
w = F.silu(self.w1(x)) * self.w3(x)
n = w.shape[-1]
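            # hadamard_transform is unnormalized; dividing by sqrt(n) makes the R4 rotation orthonormal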
return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())

for name, child in module.named_children():
if isinstance(child, FeedForward):
setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
else:
_inject_fast_hadamard_transform_cuda_for_spin_quant(child)


def inject_fast_hadamard_transform_cuda_for_spin_quant(
module: torch.nn.Module,
) -> torch.nn.Module:
_inject_fast_hadamard_transform_cuda_for_spin_quant(module)
return module
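
For readers without the CUDA kernel, a pure-PyTorch reference of the unnormalized Walsh-Hadamard transform can be useful for sanity-checking the injected layer on CPU. This is a hedged sketch: it assumes fast-hadamard-transform uses the Sylvester ordering and an unnormalized convention (consistent with the explicit division by sqrt(n) in FeedForwardCustom above); it is not the library's kernel:

    import torch

    def reference_hadamard_transform(x: torch.Tensor) -> torch.Tensor:
        """Unnormalized Walsh-Hadamard transform over the last dim (power of two)."""
        n = x.shape[-1]
        assert n > 0 and (n & (n - 1)) == 0, "last dim must be a power of two"
        h = torch.ones(1, 1, dtype=x.dtype, device=x.device)
        while h.shape[-1] < n:
            # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]]
            h = torch.cat(
                [torch.cat([h, h], dim=-1), torch.cat([h, -h], dim=-1)], dim=-2
            )
        return x @ h  # H is symmetric, so x @ h applies the transform row-wise

Under these assumptions, hadamard_transform(w) should match reference_hadamard_transform(w) up to numerical tolerance, and scaling by 1/sqrt(n) preserves activation norms.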
