Support fp32 activation for quantizing llama2 to int8 activation and int4 weight (pytorch#2032)

Summary:
Pull Request resolved: pytorch#2032

Previously we only supported quantizing fp16 activations to int8.
This adds support for quantizing fp32 activations as well to enable testing.

Representation:
https://www.internalfb.com/intern/everpaste/?handle=GAoWXBlf5O8T4TMBAP8Cps4UiVx7bsIXAAAz

Reviewed By: digantdesai

Differential Revision: D54032932

fbshipit-source-id: baaddbd385985240689444041c5de33245c86dcf
andrewor14 authored and facebook-github-bot committed Feb 23, 2024
1 parent 33306d3 commit 78ce089
Showing 3 changed files with 94 additions and 35 deletions.
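For orientation, here is a minimal sketch (not part of the commit) of how the updated quantize() entry point is intended to be called once fp32 activations are allowed. The import paths are inferred from the changed file names below, and the tiny nn.Linear stand-in replaces the real llama2 checkpoint, so treat this as illustrative only:

import torch
import torch.nn as nn

# Assumed import paths, inferred from the files changed in this diff;
# the actual package layout in the repo may differ.
from examples.models.llama2.builder import DType
from examples.models.llama2.export_llama_lib import quantize

# Toy stand-in for the llama2 model; the real flow loads a checkpoint instead.
model = nn.Sequential(nn.Linear(256, 256, bias=False))

# New in this change: the activation dtype is passed explicitly, so fp32
# activations can be quantized to int8 alongside the int4 weights.
quantized_model = quantize(
    model,
    qmode="int4",
    activation_dtype=DType.fp32,
)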
20 changes: 13 additions & 7 deletions examples/models/llama2/builder.py
@@ -46,6 +46,15 @@ class DType(Enum):
fp32 = "fp32"
fp16 = "fp16"

def to_torch_dtype(self) -> torch.dtype:
mapping = {
DType.fp32: torch.float32,
DType.fp16: torch.float16,
}
if self not in mapping:
raise ValueError(f"Unsupported dtype {self}")
return mapping[self]


def load_llama_model(
*,
@@ -145,13 +154,10 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
assert not dtype_override or isinstance(
dtype_override, DType
), "Override dtype needs to be of type <DType>"
if dtype_override == DType.fp16 and self.dtype != DType.fp16:
logging.info("model.to torch.float16")
self.model = self.model.to(dtype=torch.float16)
self.dtype = dtype_override
elif dtype_override == DType.fp32 and self.dtype != DType.fp32:
logging.info("model.to torch.float32")
self.model = self.model.to(dtype=torch.float32)
if dtype_override is not None and dtype_override != self.dtype:
torch_dtype = dtype_override.to_torch_dtype()
logging.info(f"model.to {torch_dtype}")
self.model = self.model.to(dtype=torch_dtype)
self.dtype = dtype_override
return self

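The builder.py change above hinges on the new DType.to_torch_dtype() helper, which lets to_dtype() collapse the separate fp16/fp32 branches into one generic path. A self-contained copy of the enum, runnable on its own for reference (the asserts are added here for illustration):

import torch
from enum import Enum

class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"

    def to_torch_dtype(self) -> torch.dtype:
        # Map each enum member to its torch dtype; unknown members raise.
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]

assert DType.fp32.to_torch_dtype() is torch.float32
assert DType.fp16.to_torch_dtype() is torch.float16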
42 changes: 27 additions & 15 deletions examples/models/llama2/export_llama_lib.py
@@ -11,7 +11,7 @@
import shlex
from functools import partial
from pathlib import Path
from typing import List
from typing import List, Optional

import pkg_resources
import torch
@@ -94,7 +94,11 @@ def check_embedding_byte_registered():
return quantizers


def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
def quantize(
model: torch.nn.Module,
qmode: str,
activation_dtype: Optional[DType],
) -> torch.nn.Module:
"""
Quantizes a model by converting all weights to int8.
Args:
@@ -103,14 +107,21 @@ def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
Returns:
A quantized model.
"""
if activation_dtype is not None:
torch_dtype = activation_dtype.to_torch_dtype()
else:
torch_dtype = torch.float16

if qmode == "int8":
model_int8 = WeightOnlyInt8QuantHandler(model)
model_int8_state_dict = model_int8.create_quantized_state_dict()
model_int8 = model_int8.convert_for_runtime()
model_int8.load_state_dict(model_int8_state_dict)
return model_int8
elif qmode == "int4":
model_int4 = Int8DynActInt4WeightQuantHandler(model)
model_int4 = Int8DynActInt4WeightQuantHandler(
model, activation_precision=torch_dtype
)
model_int4_state_dict = model_int4.create_quantized_state_dict()
model_int4 = model_int4.convert_for_runtime()
print("quantized model:", model_int4)
@@ -269,28 +280,29 @@ def _export_llama(modelname, args) -> str: # noqa: C901
output_dir_path = canonical_path(args.output_dir, dir=True)
modelname = "llama2"
weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA

# dtype override
if args.dtype_override is not None:
dtype_override = DType[args.dtype_override]
else:
dtype_override = DType["fp16"] if args.quantization_mode == "int4" else None

# source transforms
transforms = []
if args.quantized_ckpt or args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(partial(quantize, qmode=args.quantization_mode))
transforms.append(
partial(
quantize, qmode=args.quantization_mode, activation_dtype=dtype_override
)
)

if args.embedding_quantize:
modelname = f"{modelname}_e"
transforms.append(
lambda model: EmbeddingOnlyInt8QuantHandler(model).convert_for_runtime()
)

# dtype override
if args.dtype_override:
override = (
DType["fp16"]
if args.quantization_mode == "int4"
else DType[args.dtype_override]
)
else:
override = None

# export_to_edge
quantizers = get_pt2e_quantizers(args)

@@ -323,7 +335,7 @@ def _export_llama(modelname, args) -> str: # noqa: C901
.set_output_dir(output_dir_path)
.set_metadata(args.metadata)
.source_transform(transforms)
.to_dtype(override)
.to_dtype(dtype_override)
.export_to_edge(quantizers)
.to_backend(partitioners)
.to_executorch()
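The dtype-override selection in _export_llama above reduces to a few lines; this is a hedged paraphrase (pick_dtype_override is a hypothetical helper, not a function in the repo, and DType/quantize are the objects from the sketches above or their counterparts in builder.py and export_llama_lib.py). It shows that an explicit dtype override wins and that int4 quantization otherwise defaults to fp16 activations:

import argparse
from functools import partial

def pick_dtype_override(args):
    # Hypothetical helper mirroring the logic in _export_llama:
    # an explicit override wins; int4 quantization defaults to fp16;
    # everything else leaves the model dtype untouched (None).
    if args.dtype_override is not None:
        return DType[args.dtype_override]
    return DType["fp16"] if args.quantization_mode == "int4" else None

# Example: no explicit override, int4 weight quantization requested.
args = argparse.Namespace(dtype_override=None, quantization_mode="int4")
dtype_override = pick_dtype_override(args)
assert dtype_override == DType.fp16

# The same value is then threaded into the quantize source transform:
transform = partial(quantize, qmode=args.quantization_mode, activation_dtype=dtype_override)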
67 changes: 54 additions & 13 deletions examples/models/llama2/quantize.py
@@ -791,7 +791,14 @@ def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
return find_multiple(k, groupsize, inner_k_tiles * 16)


def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
def replace_linear_8da4w(
module,
group_size,
inner_k_tiles,
padding_allowed,
activation_precision,
weight_precision,
):
for name, child in module.named_children():
if isinstance(child, nn.Linear):
if (
@@ -807,20 +814,37 @@ def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
bias=False,
group_size=group_size,
inner_k_tiles=inner_k_tiles,
activation_precision=activation_precision,
weight_precision=weight_precision,
),
)
else:
replace_linear_8da4w(child, group_size, inner_k_tiles, padding_allowed)
replace_linear_8da4w(
child,
group_size,
inner_k_tiles,
padding_allowed,
activation_precision,
weight_precision,
)


class Int8DynActInt4WeightQuantHandler:
def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
def __init__(
self,
mod,
group_size=128,
inner_k_tiles=8,
padding_allowed=True,
activation_precision=torch.float16,
weight_precision=torch.float16,
):
self.mod = mod
self.group_size = group_size
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
# TODO: make this an argument
self.precision = torch.float16
self.activation_precision = activation_precision
self.weight_precision = weight_precision
assert group_size in [32, 64, 128, 256]
assert inner_k_tiles in [2, 4, 8]

@@ -861,7 +885,9 @@ def create_quantized_state_dict(self):
weight_int4pack,
scales_and_zeros,
) = prepare_int4_weight_and_scales_and_zeros(
weight.to(self.precision), self.group_size, self.inner_k_tiles
weight.to(self.weight_precision),
self.group_size,
self.inner_k_tiles,
)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
@@ -870,7 +896,12 @@

def convert_for_runtime(self):
replace_linear_8da4w(
self.mod, self.group_size, self.inner_k_tiles, self.padding_allowed
self.mod,
self.group_size,
self.inner_k_tiles,
self.padding_allowed,
self.activation_precision,
self.weight_precision,
)
return self.mod

@@ -891,6 +922,8 @@ def __init__(
dtype=None,
group_size: int = 128,
inner_k_tiles: int = 8,
activation_precision: torch.dtype = torch.float16,
weight_precision: torch.dtype = torch.float16,
) -> None:
super().__init__()
# always pad if needed since it becomes a noop at runtime if not needed
@@ -903,7 +936,8 @@ def __init__(
assert not bias, "require bias=False"
self.group_size = group_size
self.inner_k_tiles = inner_k_tiles
self.precision = torch.float16
self.weight_precision = weight_precision
self.activation_precision = activation_precision

# assert out_features % 8 == 0, "require out_features % 8 == 0"
assert (
@@ -917,12 +951,13 @@ def __init__(
self.register_buffer(
"scales_and_zeros",
torch.empty(
(in_features // group_size, out_features, 2), dtype=self.precision
(in_features // group_size, out_features, 2),
dtype=self.weight_precision,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input.to(self.precision)
input = input.to(self.activation_precision)
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))

(
@@ -937,15 +972,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
input, scales, zero_points, quant_min, quant_max, torch.int8
)
input = torch.ops.quantized_decomposed.dequantize_per_token(
input, scales, zero_points, quant_min, quant_max, torch.int8, self.precision
input,
scales,
zero_points,
quant_min,
quant_max,
torch.int8,
self.activation_precision,
)

input = input.to(self.precision)
input = input.to(self.activation_precision)
return linear_forward_int4(
input,
self.weight,
self.scales_and_zeros,
self.out_features,
self.group_size,
self.precision,
self.weight_precision,
)
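To close the loop on the quantize.py changes, a hedged sketch of driving Int8DynActInt4WeightQuantHandler directly with fp32 activations. The toy single-layer model, the import path, and the final load_state_dict step (which mirrors the int8 path in export_llama_lib.py) are assumptions; actually running the converted module also requires the quantized_decomposed ops used in forward() to be registered in the environment:

import torch
import torch.nn as nn

# Assumed import path, inferred from the changed file above.
from examples.models.llama2.quantize import Int8DynActInt4WeightQuantHandler

# Toy model: in_features divisible by group_size so the handler's asserts pass.
model = nn.Sequential(nn.Linear(256, 256, bias=False))

handler = Int8DynActInt4WeightQuantHandler(
    model,
    group_size=128,
    inner_k_tiles=8,
    padding_allowed=True,
    activation_precision=torch.float32,  # new in this change: fp32 activations
    weight_precision=torch.float16,      # weights/scales stay fp16 by default
)
quantized_state_dict = handler.create_quantized_state_dict()
model = handler.convert_for_runtime()
model.load_state_dict(quantized_state_dict)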
