forked from pytorch/executorch

Commit
Load Q4_0 gguf model and run in eager mode (pytorch#2510)
Summary:
Pull Request resolved: pytorch#2510

As titled.

bypass-github-export-checks

Reviewed By: mergennachin

Differential Revision: D55056265

fbshipit-source-id: bd5631c21467cc12540055bd96ed8fc795ceafe7
1 parent 04dc65a, commit 8d55f91
Showing 3 changed files with 381 additions and 2 deletions.
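For context, the new eager runner added below would be invoked roughly as follows; the script's filename is not shown on this page, so the path is a placeholder, and the flags come from the argparse definitions in the diff:

    python <path-to-eager-runner>.py --gguf_file /path/to/model-q4_0.gguf --tokenizer_path /path/to/tokenizer.model --prompt "Hello, my name is"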
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama2.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor

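# Note on the sampler below: dividing the probabilities by i.i.d. Exponential(1) noise and taking
# the argmax selects index i with probability probs[i] / probs.sum() (an exponential-race variant
# of the Gumbel-max trick), which matches multinomial sampling while avoiding the device
# synchronization that torch.multinomial would trigger.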
def multinomial_sample_one_no_sync(
    probs_sort,
):  # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

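# Note on the helper below: the pivot is the smallest of the top-k logits; every logit strictly
# below it is masked to -inf before the softmax, so after temperature scaling the probability
# mass is spread over at most the k most likely tokens.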
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[0, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def encode_tokens(tokenizer, string, bos=True, device="cpu"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)


def decode_one_token(
    model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = model(x)
    return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    return decode_one_token(model, x, **sampling_kwargs)[0]

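# Note on the loop below: the KV-cache/input_pos path from gpt-fast is left commented out, so
# each step concatenates the new token onto cur_token and re-runs the model over the full
# sequence. That keeps eager execution simple at the cost of quadratic work in sequence length.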
def decode_n_tokens(
    model: torch.nn.Module,
    cur_token: torch.Tensor,
    num_new_tokens: int,
    callback=lambda _: _,
    **sampling_kwargs,
):
    print(f"cur_token: {cur_token}")
    new_tokens, new_probs = [], []
    for _ in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_mem_efficient=False, enable_math=True
        ):  # Actually better for Inductor to codegen attention here
            next_token, next_prob = decode_one_token(
                model, cur_token.view(1, -1), **sampling_kwargs
            )
            new_tokens.append(next_token.clone())
            # print(next_token)
            callback(next_token)
            new_probs.append(next_prob.clone())
            cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
            # print(cur_token)

    return new_tokens, new_probs


@torch.no_grad()
def generate(
    model: torch.nn.Module,
    prompt: torch.Tensor,
    max_new_tokens: int,
    *,
    interactive: bool,
    callback=lambda x: x,
    **sampling_kwargs,
) -> torch.Tensor:
    """
    Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
    """

    # create an empty tensor of the expected final shape and fill in the current tokens
    T = prompt.size(0)
    T_new = T + max_new_tokens
    # if interactive:
    #     max_seq_length = 350
    # else:
    #     max_seq_length = min(T_new, model.params.max_seq_len)

    device, dtype = prompt.device, prompt.dtype

    # with torch.device(device):
    #     model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    # input_pos = torch.arange(0, T, device=device)

    next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
    seq[T] = next_token
    callback(next_token)

    cur_tokens = torch.cat((prompt, next_token), dim=0)
    # input_pos = torch.tensor([T], device=device, dtype=torch.int)

    generated_tokens, _ = decode_n_tokens(
        model,
        cur_tokens.view(1, -1),
        # input_pos,
        max_new_tokens - 1,
        callback=callback,
        **sampling_kwargs,
    )
    seq[T + 1 :] = torch.cat(generated_tokens)

    return seq


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gguf_file",
        type=str,
        help="The GGUF file to load.",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        help="The tokenizer.model path.",
    )
    parser.add_argument(
        "--prompt", type=str, default="Hello, my name is", help="Input prompt."
    )

    args = parser.parse_args()

    tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
    encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

    pt_model = load_gguf_q4_0(args.gguf_file)

    max_new_tokens = 100
    buffer = [tokenizer.decode(encoded.tolist())]
    period_id = tokenizer.encode(".")[0]
    done_generating = False

    def callback(x):
        nonlocal done_generating
        if done_generating:
            return
        buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
        if x.item() == tokenizer.eos_id():
            done_generating = True
        if len(buffer) == 4 or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    generate(
        pt_model,
        encoded,
        max_new_tokens,
        interactive=False,
        callback=callback,
        temperature=1.0,
        top_k=10,
    )


if __name__ == "__main__":
    main()
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Load llama model from a GGUF file, quantized in Q4_0 format.
# For float weights, we load them directly from the GGUF file.
# For Q4_0 weights, we load them into a Tensor subclass (GGMLInt4LinearWeight).
# This is done by replacing the linear weight with the subclass.
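#
# Background on the packing assumed throughout this file: a GGML Q4_0 tensor stores weights in
# blocks of 32, and each block is packed into 18 bytes (a 2-byte float16 scale followed by
# 16 bytes holding 32 4-bit values). That is why packed data is reshaped to [numel / 32, 18]
# below and the scale is recovered from the first two bytes of each row.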

import logging
import os
from typing import Callable, Dict, Mapping

import torch
from executorch.examples.models.llama2.experimental.subclass import (
    _unpack_two_uint8,
    GGMLInt4LinearWeight,
    to_float,
)
from executorch.extension.gguf_util.converters.llama_converter import (
    _convert_gguf_tensor_name_to_llama_nn,
    _create_pt_model,
)
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
from torchao.quantization.subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _replace_with_custom_fn_if_matches_filter(
    # pyre-fixme[2]: Parameter must be annotated.
    model,
    replacement_fn,
    filter_fn,
    cur_fqn="",
) -> torch.nn.Module:
    """
    For each `child` in `model`, replaces it with `replacement_fn(child)`
    if `filter_fn(child)` is `True`.
    """
    if filter_fn(model, cur_fqn[:-1]):
        model = replacement_fn(model, cur_fqn[:-1])
        return model
    else:
        for name, child in model.named_children():
            new_child = _replace_with_custom_fn_if_matches_filter(
                child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
            )
            if new_child is not child:
                setattr(model, name, new_child)
        return model


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_subclass_inserter(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
    def insert_subclass(lin, fqn):
        # TODO: replace weights with gguf format tensor
        # packed tensor should have size [numel / 32, 18]
        fqn = fqn + ".weight"
        assert (
            fqn in weight_map
        ), f"Expect {fqn} to be in weight map but not found. All keys are {weight_map.keys()}"
        tensor = weight_map[fqn]
        print(fqn, tensor.shape, tensor.data.shape, lin.weight.shape)
        packed = torch.from_numpy(tensor.data).reshape(-1, 18)
        scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
        lin.weight = torch.nn.Parameter(
            GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
        )
        return lin

    return insert_subclass


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_filter_fn(
    weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], bool]:
    def _is_linear(mod, fqn):
        return (
            isinstance(mod, torch.nn.Linear)
            and hasattr(mod, "weight")
            and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
            and not isinstance(mod.weight, QuantizedLinearWeightBase)
        )

    return _is_linear


def change_linear_weights_to_q4_0_tensors(
    model: torch.nn.Module, gguf_weights: GGUFWeights
) -> None:
    """
    Converts all linear weight tensors to the
    `GGMLInt4LinearWeight` tensor subclass,
    effectively applying the same form of quantization
    as apply_dynamic_quant while not modifying the linear modules.
    """
    assert gguf_weights is not None, "Must provide gguf_weights"
    weight_map = {
        _convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
        for tensor in gguf_weights.tensors
    }

    _replace_with_custom_fn_if_matches_filter(
        model,
        _get_subclass_inserter(weight_map),
        _get_filter_fn(weight_map),
    )


def get_float_weights(
    pt_model: torch.nn.Module, gguf_weights: GGUFWeights
) -> Mapping[str, torch.Tensor]:
    """
    Returns a mapping from the fqn to the float weight tensor. Even though
    the model is quantized in Q4_0, these weights are still stored as float.

    Args:
        pt_model (torch.nn.Module): The model to load the weights into.
        gguf_weights (GGUFWeights): The GGUF weights to extract the tensors from.
    """
    state_dict = {}
    for tensor in gguf_weights.tensors:
        model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name)
        if (
            tensor.tensor_type == GGMLQuantizationType.F32
            or tensor.tensor_type == GGMLQuantizationType.F16
        ):
            print(tensor.name)
            reversed_shape = tensor.shape[::-1]
            new_tensor = tensor.data.reshape(reversed_shape)
            state_dict[model_key] = torch.from_numpy(new_tensor)
        # Load token_embd.weight, which is quantized in Q4_0, and dequantize it into float.
        elif tensor.tensor_type == GGMLQuantizationType.Q4_0:
            if tensor.name == "token_embd.weight":
                print(tensor.name)
                unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
                state_dict[model_key] = unpacked.reshape(
                    pt_model.params.vocab_size, pt_model.params.dim
                )

    # We need to fake initialize the mask, to match with the llama_transformer.py
    for id in range(pt_model.params.n_layers):
        mask_name = f"layers.{id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask
    return state_dict


def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module:
    assert os.path.isfile(gguf_file), f"Expect a valid gguf_file path, got {gguf_file}"

    logging.info(f"Loading GGUF file: {gguf_file}")
    gguf_model_args, gguf_weights = load_file(gguf_file)

    logging.info("Creating the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    logging.info("Load float weights")
    state_dict = get_float_weights(pt_model, gguf_weights)
    pt_model.load_state_dict(state_dict, strict=False)

    logging.info("Change linear weights to Q4_0 tensors")
    change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights)

    pt_model = pt_model.to(dtype=torch.float16)

    return pt_model
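
For reference, a minimal sketch (not part of this diff; paths are placeholders) of how the loader above composes with the eager runner added in the first file:

    import torch
    from executorch.examples.models.llama2.experimental.load_gguf_q4_0 import load_gguf_q4_0
    from sentencepiece import SentencePieceProcessor

    pt_model = load_gguf_q4_0("/path/to/llama2-q4_0.gguf")  # fp16 model with Q4_0 linear weights
    tokenizer = SentencePieceProcessor(model_file="/path/to/tokenizer.model")
    tokens = torch.tensor(
        [tokenizer.bos_id()] + tokenizer.encode("Hello, my name is"), dtype=torch.int
    )
    logits = pt_model(tokens.view(1, -1))  # eager forward pass over the prompt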