diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index ec3a139147..ea9cf39741 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -46,6 +46,15 @@ class DType(Enum):
     fp32 = "fp32"
     fp16 = "fp16"

+    def to_torch_dtype(self) -> torch.dtype:
+        mapping = {
+            DType.fp32: torch.float32,
+            DType.fp16: torch.float16,
+        }
+        if self not in mapping:
+            raise ValueError(f"Unsupported dtype {self}")
+        return mapping[self]
+

 def load_llama_model(
     *,
@@ -145,13 +154,10 @@ def to_dtype(self, dtype_override: Optional[DType]) -> "LlamaEdgeManager":
         assert not dtype_override or isinstance(
             dtype_override, DType
         ), "Override dtype needs to be of type "
-        if dtype_override == DType.fp16 and self.dtype != DType.fp16:
-            logging.info("model.to torch.float16")
-            self.model = self.model.to(dtype=torch.float16)
-            self.dtype = dtype_override
-        elif dtype_override == DType.fp32 and self.dtype != DType.fp32:
-            logging.info("model.to torch.float32")
-            self.model = self.model.to(dtype=torch.float32)
+        if dtype_override is not None and dtype_override != self.dtype:
+            torch_dtype = dtype_override.to_torch_dtype()
+            logging.info(f"model.to {torch_dtype}")
+            self.model = self.model.to(dtype=torch_dtype)
             self.dtype = dtype_override

         return self
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index c79a85a9ec..0a584a0361 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -11,7 +11,7 @@
 import shlex
 from functools import partial
 from pathlib import Path
-from typing import List
+from typing import List, Optional

 import pkg_resources
 import torch
@@ -94,7 +94,11 @@ def check_embedding_byte_registered():
     return quantizers


-def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
+def quantize(
+    model: torch.nn.Module,
+    qmode: str,
+    activation_dtype: Optional[DType],
+) -> torch.nn.Module:
     """
     Quantizes a model by converting all weights to int8.
     Args:
@@ -103,6 +107,11 @@ def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
     Returns:
         A quantized model.
     """
+    if activation_dtype is not None:
+        torch_dtype = activation_dtype.to_torch_dtype()
+    else:
+        torch_dtype = torch.float16
+
     if qmode == "int8":
         model_int8 = WeightOnlyInt8QuantHandler(model)
         model_int8_state_dict = model_int8.create_quantized_state_dict()
@@ -110,7 +119,9 @@ def quantize(model: torch.nn.Module, qmode: str) -> torch.nn.Module:
         model_int8.load_state_dict(model_int8_state_dict)
         return model_int8
     elif qmode == "int4":
-        model_int4 = Int8DynActInt4WeightQuantHandler(model)
+        model_int4 = Int8DynActInt4WeightQuantHandler(
+            model, activation_precision=torch_dtype
+        )
         model_int4_state_dict = model_int4.create_quantized_state_dict()
         model_int4 = model_int4.convert_for_runtime()
         print("quantized model:", model_int4)
@@ -269,11 +280,22 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     output_dir_path = canonical_path(args.output_dir, dir=True)
     modelname = "llama2"
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
+
+    # dtype override
+    if args.dtype_override is not None:
+        dtype_override = DType[args.dtype_override]
+    else:
+        dtype_override = DType["fp16"] if args.quantization_mode == "int4" else None
+
     # source transforms
     transforms = []
     if args.quantized_ckpt or args.quantization_mode:
         modelname = f"{modelname}_q"
-        transforms.append(partial(quantize, qmode=args.quantization_mode))
+        transforms.append(
+            partial(
+                quantize, qmode=args.quantization_mode, activation_dtype=dtype_override
+            )
+        )

     if args.embedding_quantize:
         modelname = f"{modelname}_e"
@@ -281,16 +303,6 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             lambda model: EmbeddingOnlyInt8QuantHandler(model).convert_for_runtime()
         )

-    # dtype override
-    if args.dtype_override:
-        override = (
-            DType["fp16"]
-            if args.quantization_mode == "int4"
-            else DType[args.dtype_override]
-        )
-    else:
-        override = None
-
     # export_to_edge
     quantizers = get_pt2e_quantizers(args)

@@ -323,7 +335,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         .set_output_dir(output_dir_path)
         .set_metadata(args.metadata)
         .source_transform(transforms)
-        .to_dtype(override)
+        .to_dtype(dtype_override)
         .export_to_edge(quantizers)
         .to_backend(partitioners)
         .to_executorch()
diff --git a/examples/models/llama2/quantize.py b/examples/models/llama2/quantize.py
index b87c014ea6..7592eebca7 100644
--- a/examples/models/llama2/quantize.py
+++ b/examples/models/llama2/quantize.py
@@ -791,7 +791,14 @@ def _calc_padded_size_linear_int4(k, groupsize=1, inner_k_tiles=1):
     return find_multiple(k, groupsize, inner_k_tiles * 16)


-def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
+def replace_linear_8da4w(
+    module,
+    group_size,
+    inner_k_tiles,
+    padding_allowed,
+    activation_precision,
+    weight_precision,
+):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if (
@@ -807,20 +814,37 @@ def replace_linear_8da4w(module, group_size, inner_k_tiles, padding_allowed):
                         bias=False,
                         group_size=group_size,
                         inner_k_tiles=inner_k_tiles,
+                        activation_precision=activation_precision,
+                        weight_precision=weight_precision,
                     ),
                 )
         else:
-            replace_linear_8da4w(child, group_size, inner_k_tiles, padding_allowed)
+            replace_linear_8da4w(
+                child,
+                group_size,
+                inner_k_tiles,
+                padding_allowed,
+                activation_precision,
+                weight_precision,
+            )


 class Int8DynActInt4WeightQuantHandler:
-    def __init__(self, mod, group_size=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(
+        self,
+        mod,
+        group_size=128,
+        inner_k_tiles=8,
+        padding_allowed=True,
+        activation_precision=torch.float16,
+        weight_precision=torch.float16,
+    ):
         self.mod = mod
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
         self.padding_allowed = padding_allowed
-        # TODO: make this an argument
-        self.precision = torch.float16
+        self.activation_precision = activation_precision
+        self.weight_precision = weight_precision
         assert group_size in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]

@@ -861,7 +885,9 @@ def create_quantized_state_dict(self):
                     weight_int4pack,
                     scales_and_zeros,
                 ) = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(self.precision), self.group_size, self.inner_k_tiles
+                    weight.to(self.weight_precision),
+                    self.group_size,
+                    self.inner_k_tiles,
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu")
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu")
@@ -870,7 +896,12 @@ def create_quantized_state_dict(self):

     def convert_for_runtime(self):
         replace_linear_8da4w(
-            self.mod, self.group_size, self.inner_k_tiles, self.padding_allowed
+            self.mod,
+            self.group_size,
+            self.inner_k_tiles,
+            self.padding_allowed,
+            self.activation_precision,
+            self.weight_precision,
         )
         return self.mod

@@ -891,6 +922,8 @@ def __init__(
         dtype=None,
         group_size: int = 128,
         inner_k_tiles: int = 8,
+        activation_precision: torch.dtype = torch.float16,
+        weight_precision: torch.dtype = torch.float16,
     ) -> None:
         super().__init__()
         # always pad if needed since it becomes a noop at runtime if not needed
@@ -903,7 +936,8 @@ def __init__(
         assert not bias, "require bias=False"
         self.group_size = group_size
         self.inner_k_tiles = inner_k_tiles
-        self.precision = torch.float16
+        self.weight_precision = weight_precision
+        self.activation_precision = activation_precision

         # assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert (
@@ -917,12 +951,13 @@ def __init__(
         self.register_buffer(
             "scales_and_zeros",
             torch.empty(
-                (in_features // group_size, out_features, 2), dtype=self.precision
+                (in_features // group_size, out_features, 2),
+                dtype=self.weight_precision,
             ),
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))

         (
@@ -937,15 +972,21 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             input, scales, zero_points, quant_min, quant_max, torch.int8
         )
         input = torch.ops.quantized_decomposed.dequantize_per_token(
-            input, scales, zero_points, quant_min, quant_max, torch.int8, self.precision
+            input,
+            scales,
+            zero_points,
+            quant_min,
+            quant_max,
+            torch.int8,
+            self.activation_precision,
         )
-        input = input.to(self.precision)
+        input = input.to(self.activation_precision)
         return linear_forward_int4(
             input,
             self.weight,
             self.scales_and_zeros,
             self.out_features,
             self.group_size,
-            self.precision,
+            self.weight_precision,
         )
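
For reference, the dtype-override behavior this patch introduces can be exercised on its own. The sketch below is illustrative only: it re-declares a minimal DType mirroring the builder.py hunk and a hypothetical pick_dtype_override helper that mirrors the selection logic moved into _export_llama, rather than importing the real examples.models.llama2 modules; the SimpleNamespace objects stand in for parsed CLI args.

# Illustrative sketch of the patch's dtype-override rules; pick_dtype_override
# and the SimpleNamespace "args" objects are hypothetical stand-ins.
from enum import Enum
from types import SimpleNamespace
from typing import Optional

import torch


class DType(Enum):
    fp32 = "fp32"
    fp16 = "fp16"

    def to_torch_dtype(self) -> torch.dtype:
        # Same mapping as the new DType.to_torch_dtype helper in builder.py.
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]


def pick_dtype_override(args) -> Optional[DType]:
    # Mirrors the block now placed ahead of the source transforms in
    # _export_llama: an explicit --dtype_override wins; otherwise int4
    # quantization implies fp16, and anything else leaves the dtype alone.
    if args.dtype_override is not None:
        return DType[args.dtype_override]
    return DType["fp16"] if args.quantization_mode == "int4" else None


assert pick_dtype_override(
    SimpleNamespace(dtype_override=None, quantization_mode="int4")
) == DType.fp16
assert pick_dtype_override(
    SimpleNamespace(dtype_override="fp32", quantization_mode="int4")
) == DType.fp32
assert pick_dtype_override(
    SimpleNamespace(dtype_override=None, quantization_mode="int8")
) is None
assert DType.fp16.to_torch_dtype() == torch.float16

The same override value now feeds both LlamaEdgeManager.to_dtype and quantize(..., activation_dtype=...), which is why the patch computes it once before building the transform list.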