spinquant in eager mode (pytorch#5125)
Summary:
Pull Request resolved: pytorch#5125

This PR adds the option to export the model with SpinQuant on GPU.

Reviewed By: mergennachin

Differential Revision: D62042861

fbshipit-source-id: 74274fcb3408e5f6b23e0c924272385090da03d2
1 parent 69aed24 · commit 41bc1ce
Showing 3 changed files with 124 additions and 42 deletions.
examples/models/llama2/source_transformation/spin_quant.py
55 changes: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model to be able to run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

import torch
import torch.nn.functional as F

from executorch.examples.models.llama2.llama_transformer import FeedForward
from torch import nn


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 into the
    feed-forward layers; R3 needs to be injected as well when KV cache quantization
    is enabled.
    """
    try:
        from fast_hadamard_transform import hadamard_transform
    except ImportError:
        raise ImportError(
            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
        )

    class FeedForwardCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            # SwiGLU feed-forward, followed by the R4 rotation: the fast Hadamard
            # transform scaled by 1/sqrt(n) so that the transform is orthogonal.
            w = F.silu(self.w1(x)) * self.w3(x)
            n = w.shape[-1]
            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())

    # Recursively replace every FeedForward block with the R4-augmented version.
    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)


def inject_fast_hadamard_transform_cuda_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
    return module
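
A note on the scaling in FeedForwardCustom.forward: a Hadamard matrix H_n satisfies H_n @ H_n.T == n * I, so H_n / sqrt(n) is orthogonal, i.e. a rotation, which is what SpinQuant's R4 must be. A minimal pure-PyTorch sketch of that property (the hadamard_matrix helper is illustrative only, not part of this diff, and assumes n is a power of two):

import torch

def hadamard_matrix(n: int) -> torch.Tensor:
    # Sylvester construction: H_{2k} = [[H_k, H_k], [H_k, -H_k]].
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h

n = 8
r4 = hadamard_matrix(n) / torch.tensor(float(n)).sqrt()
# Orthogonality check: R4 @ R4.T is the identity, so R4 is a rotation.
assert torch.allclose(r4 @ r4.T, torch.eye(n), atol=1e-6)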
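
A usage sketch for the public entry point. Assumptions not confirmed by this diff: executorch and fast-hadamard-transform are installed, a CUDA device is available for the fused kernel, and FeedForward is constructed from a ModelArgs as in llama_transformer; the tiny config below is hypothetical.

import torch

from executorch.examples.models.llama2.llama_transformer import FeedForward, ModelArgs
from executorch.examples.models.llama2.source_transformation.spin_quant import (
    inject_fast_hadamard_transform_cuda_for_spin_quant,
)

args = ModelArgs(dim=64, hidden_dim=256)  # hypothetical tiny config
model = torch.nn.Sequential(FeedForward(args)).cuda()
model = inject_fast_hadamard_transform_cuda_for_spin_quant(model)

x = torch.randn(1, 8, 64, device="cuda")
y = model(x)  # the feed-forward path now applies R4 via the fused Hadamard kernel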