diff --git a/examples/models/llama2/experimental/README.md b/examples/models/llama2/experimental/README.md
index e267f6fe5c..964030d7c4 100644
--- a/examples/models/llama2/experimental/README.md
+++ b/examples/models/llama2/experimental/README.md
@@ -22,9 +22,15 @@ cmake --build . --config Release
 
 # prepare model
 python3 convert.py models/llama7b
-# quantize
-build/bin/quantize models/llama7b/ggml-model-f16.gguf models/llama7b/ggml-model-Q4_0.gguf Q4_0
+# quantize. Note that we use --pure to prevent Q6_K tensors from showing up.
+build/bin/quantize --pure models/llama7b/ggml-model-f16.gguf models/llama7b/ggml-model-Q4_0.gguf Q4_0
 ```
 
 We want to load it back into a `torch.nn.Module` and run in eager mode.
 The way it works is through a Tensor subclass.
+
+
+## Generate Tokens in PyTorch Eager
+```bash
+python3 generate.py --prompt "Once upon a time" --gguf_file models/llama7b/ggml-model-Q4_0.gguf --tokenizer_path models/llama7b/tokenizer.model
+```
diff --git a/examples/models/llama2/experimental/generate.py b/examples/models/llama2/experimental/generate.py
new file mode 100644
index 0000000000..bc974d7351
--- /dev/null
+++ b/examples/models/llama2/experimental/generate.py
@@ -0,0 +1,190 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+import argparse
+
+from typing import Optional, Tuple
+
+import torch
+
+from executorch.examples.models.llama2.experimental.load_gguf_q4_0 import load_gguf_q4_0
+from sentencepiece import SentencePieceProcessor
+
+
+def multinomial_sample_one_no_sync(
+    probs_sort,
+):  # Does multinomial sampling without a cuda synchronization
+    q = torch.empty_like(probs_sort).exponential_(1)
+    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    logits = logits / max(temperature, 1e-5)
+
+    if top_k is not None:
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+        pivot = v.select(-1, -1).unsqueeze(-1)
+        logits = torch.where(logits < pivot, -float("Inf"), logits)
+    probs = torch.nn.functional.softmax(logits, dim=-1)
+    return probs
+
+
+def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
+    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    idx_next = multinomial_sample_one_no_sync(probs)
+    return idx_next, probs
+
+
+def encode_tokens(tokenizer, string, bos=True, device="cpu"):
+    tokens = tokenizer.encode(string)
+    if bos:
+        tokens = [tokenizer.bos_id()] + tokens
+    return torch.tensor(tokens, dtype=torch.int, device=device)
+
+
+def decode_one_token(
+    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    logits = model(x)
+    return sample(logits, **sampling_kwargs)
+
+
+def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
+    return decode_one_token(model, x, **sampling_kwargs)[0]
+
+
+def decode_n_tokens(
+    model: torch.nn.Module,
+    cur_token: torch.Tensor,
+    num_new_tokens: int,
+    callback=lambda _: _,
+    **sampling_kwargs,
+):
+    print(f"cur_token: {cur_token}")
+    new_tokens, new_probs = [], []
+    for _ in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=False, enable_mem_efficient=False, enable_math=True
+        ):  # Actually better for Inductor to codegen attention here
+            next_token, next_prob = decode_one_token(
+                model, cur_token.view(1, -1), **sampling_kwargs
+            )
+            new_tokens.append(next_token.clone())
+            # print(next_token)
+            callback(next_token)
+            new_probs.append(next_prob.clone())
+            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
+            # print(cur_token)
+
+    return new_tokens, new_probs
+
+
+@torch.no_grad()
+def generate(
+    model: torch.nn.Module,
+    prompt: torch.Tensor,
+    max_new_tokens: int,
+    *,
+    interactive: bool,
+    callback=lambda x: x,
+    **sampling_kwargs,
+) -> torch.Tensor:
+    """
+    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+    """
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = prompt.size(0)
+    T_new = T + max_new_tokens
+    # if interactive:
+    #     max_seq_length = 350
+    # else:
+    #     max_seq_length = min(T_new, model.params.max_seq_len)
+
+    device, dtype = prompt.device, prompt.dtype
+
+    # with torch.device(device):
+    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
+
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty[:T] = prompt
+    seq = empty
+    # input_pos = torch.arange(0, T, device=device)
+
+    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
+    seq[T] = next_token
+    callback(next_token)
+
+    cur_tokens = torch.cat((prompt, next_token), dim=0)
+    # input_pos = torch.tensor([T], device=device, dtype=torch.int)
+
+    generated_tokens, _ = decode_n_tokens(
+        model,
+        cur_tokens.view(1, -1),
+        # input_pos,
+        max_new_tokens - 1,
+        callback=callback,
+        **sampling_kwargs,
+    )
+    seq[T + 1 :] = torch.cat(generated_tokens)
+
+    return seq
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--gguf_file",
+        type=str,
+        help="The GGUF file to load.",
+    )
+    parser.add_argument(
+        "--tokenizer_path",
+        type=str,
+        help="The tokenizer.model path.",
+    )
+    parser.add_argument(
+        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
+    )
+
+    args = parser.parse_args()
+
+    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
+    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")
+
+    pt_model = load_gguf_q4_0(args.gguf_file)
+
+    max_new_tokens = 100
+    buffer = [tokenizer.decode(encoded.tolist())]
+    period_id = tokenizer.encode(".")[0]
+    done_generating = False
+
+    def callback(x):
+        nonlocal done_generating
+        if done_generating:
+            return
+        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
+        if x.item() == tokenizer.eos_id():
+            done_generating = True
+        if len(buffer) == 4 or done_generating:
+            print("".join(buffer), end="", flush=True)
+            buffer.clear()
+
+    generate(
+        pt_model,
+        encoded,
+        max_new_tokens,
+        interactive=False,
+        callback=callback,
+        temperature=1.0,
+        top_k=10,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama2/experimental/load_gguf_q4_0.py b/examples/models/llama2/experimental/load_gguf_q4_0.py
new file mode 100644
index 0000000000..4583978394
--- /dev/null
+++ b/examples/models/llama2/experimental/load_gguf_q4_0.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Load llama model from a GGUF file, quantized in Q4_0 format.
+# For float weights, we load them directly from the GGUF file.
+# For Q4_0 weights, we load them into a Tensor subclass (GGMLInt4LinearWeight).
+# This is done by replacing the linear weight with the subclass.
+
+import logging
+import os
+from typing import Callable, Dict, Mapping
+
+import torch
+from executorch.examples.models.llama2.experimental.subclass import (
+    _unpack_two_uint8,
+    GGMLInt4LinearWeight,
+    to_float,
+)
+from executorch.extension.gguf_util.converters.llama_converter import (
+    _convert_gguf_tensor_name_to_llama_nn,
+    _create_pt_model,
+)
+from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
+from gguf import ReaderTensor
+from gguf.constants import GGMLQuantizationType
+from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def _replace_with_custom_fn_if_matches_filter(
+    # pyre-fixme[2]: Parameter must be annotated.
+    model,
+    replacement_fn,
+    filter_fn,
+    cur_fqn="",
+) -> torch.nn.Module:
+    """
+    For each `child` in `model`, replaces it with `replacement_fn(child)`
+    if `filter_fn(child)` is `True`.
+    """
+    if filter_fn(model, cur_fqn[:-1]):
+        model = replacement_fn(model, cur_fqn[:-1])
+        return model
+    else:
+        for name, child in model.named_children():
+            new_child = _replace_with_custom_fn_if_matches_filter(
+                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
+            )
+            if new_child is not child:
+                setattr(model, name, new_child)
+        return model
+
+
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def _get_subclass_inserter(
+    weight_map: Dict[str, ReaderTensor]
+) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
+    def insert_subclass(lin, fqn):
+        # TODO: replace weights with gguf format tensor
+        # packed tensor should have size [numel / 32, 18]
+        fqn = fqn + ".weight"
+        assert (
+            fqn in weight_map
+        ), f"Expect {fqn} to be in weight map but not found. All keys are {weight_map.keys()}"
+        tensor = weight_map[fqn]
+        print(fqn, tensor.shape, tensor.data.shape, lin.weight.shape)
+        packed = torch.from_numpy(tensor.data).reshape(-1, 18)
+        scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
+        lin.weight = torch.nn.Parameter(
+            GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
+        )
+        return lin
+
+    return insert_subclass
+
+
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def _get_filter_fn(
+    weight_map: Dict[str, ReaderTensor]
+) -> Callable[[torch.nn.Module, str], bool]:
+    def _is_linear(mod, fqn):
+        return (
+            isinstance(mod, torch.nn.Linear)
+            and hasattr(mod, "weight")
+            and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
+            and not isinstance(mod.weight, QuantizedLinearWeightBase)
+        )
+
+    return _is_linear
+
+
+def change_linear_weights_to_q4_0_tensors(
+    model: torch.nn.Module, gguf_weights: GGUFWeights
+) -> None:
+    """
+    Converts all linear weight tensors to the
+    `GGMLInt4LinearWeight` tensor subclass,
+    effectively applying weight-only 4-bit (Q4_0)
+    quantization while not modifying the linear modules.
+ """ + assert gguf_weights is not None, "Must provide gguf_weights" + weight_map = { + _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor + for tensor in gguf_weights.tensors + } + + _replace_with_custom_fn_if_matches_filter( + model, + _get_subclass_inserter(weight_map), + _get_filter_fn(weight_map), + ) + + +def get_float_weights( + pt_model: torch.nn.Module, gguf_weights: GGUFWeights +) -> Mapping[str, torch.Tensor]: + """ + Returns a mapping from the fqn to the float weight tensor. Even though + the model is quantized in Q4_0, these weights are still stored as float. + Args: + pt_model (torch.nn.Module): The model to load the weights. + gguf_weights (GGUFWeights): The weights to extract the weights from. + """ + state_dict = {} + for tensor in gguf_weights.tensors: + model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name) + if ( + tensor.tensor_type == GGMLQuantizationType.F32 + or tensor.tensor_type == GGMLQuantizationType.F16 + ): + print(tensor.name) + reversed_shape = tensor.shape[::-1] + new_tensor = tensor.data.reshape(reversed_shape) + state_dict[model_key] = torch.from_numpy(new_tensor) + # Load token_embd.weight which is quantized in Q4_0 and we dequantize it into float. + elif tensor.tensor_type == GGMLQuantizationType.Q4_0: + if tensor.name == "token_embd.weight": + print(tensor.name) + unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18))) + state_dict[model_key] = unpacked.reshape( + pt_model.params.vocab_size, pt_model.params.dim + ) + + # We need to fake initialize the mask, to match with the llama_transformer.py + for id in range(pt_model.params.n_layers): + mask_name = f"layers.{id}.attention.mask" + mask = torch.full( + (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len), + float("-inf"), + ) + mask = torch.triu(mask, diagonal=1) + state_dict[mask_name] = mask + return state_dict + + +def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module: + assert os.path.isfile(gguf_file), f"Expect a valid gguf_file path, got {gguf_file}" + + logging.info(f"Loading GGUF file: {gguf_file}") + gguf_model_args, gguf_weights = load_file(gguf_file) + + logging.info("Creating the PyTorch model") + pt_model = _create_pt_model( + gguf_model_args, + ) + + logging.info("Load float weights") + state_dict = get_float_weights(pt_model, gguf_weights) + pt_model.load_state_dict(state_dict, strict=False) + + logging.info("Change linear weights to Q4_0 tensors") + change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights) + + pt_model = pt_model.to(dtype=torch.float16) + + return pt_model