Load Q4_0 gguf model and run in eager mode (pytorch#2510)
Summary:
Pull Request resolved: pytorch#2510

As titled

bypass-github-export-checks

Reviewed By: mergennachin

Differential Revision: D55056265

fbshipit-source-id: bd5631c21467cc12540055bd96ed8fc795ceafe7
larryliu0820 authored and facebook-github-bot committed Mar 20, 2024
1 parent 04dc65a commit 8d55f91
Showing 3 changed files with 381 additions and 2 deletions.
10 changes: 8 additions & 2 deletions examples/models/llama2/experimental/README.md
@@ -22,9 +22,15 @@ cmake --build . --config Release
# prepare model
python3 convert.py models/llama7b

-# quantize
-build/bin/quantize models/llama7b/ggml-model-f16.gguf models/llama7b/ggml-model-Q4_0.gguf Q4_0
+# quantize. Note: pass --pure so that no Q6_K tensors show up in the output.
+build/bin/quantize --pure models/llama7b/ggml-model-f16.gguf models/llama7b/ggml-model-Q4_0.gguf Q4_0

```

We want to load it back into a `torch.nn.Module` and run it in eager mode. This works through a Tensor subclass (`GGMLInt4LinearWeight`) that replaces each quantized linear weight and holds the packed Q4_0 data.
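Below is a minimal sketch of the mechanism, not the actual `GGMLInt4LinearWeight` implementation: all names, dtypes, and the placeholder dequantization are illustrative assumptions. The idea is that a `torch.Tensor` subclass advertises the logical weight shape, carries the packed Q4_0 blocks, and intercepts `F.linear` via `__torch_function__` so the weight can be unpacked on the fly.

```python
import torch
import torch.nn.functional as F


class PackedQ4Sketch(torch.Tensor):
    """Illustrative stand-in for a packed Q4_0 weight subclass."""

    @staticmethod
    def __new__(cls, packed, scale, shape):
        # The wrapper advertises the logical (dequantized) shape and dtype.
        # (A real implementation could use _make_wrapper_subclass to avoid
        # allocating this dense placeholder buffer.)
        return torch.Tensor._make_subclass(cls, torch.empty(shape, dtype=torch.float32))

    def __init__(self, packed, scale, shape):
        self.packed = packed  # e.g. uint8 blocks of shape [numel // 32, 18]
        self.scale = scale    # one scale per 32-element block

    def unpack_to_float(self):
        # Placeholder: a real implementation unpacks the 4-bit blocks here.
        return torch.zeros(self.shape, dtype=torch.float32)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else kwargs.get("bias")
            return F.linear(x, w.unpack_to_float(), bias)
        return super().__torch_function__(func, types, args, kwargs)


# Eager mode "just works": F.linear dispatches through __torch_function__.
w = PackedQ4Sketch(torch.zeros(64, 18, dtype=torch.uint8), torch.ones(64), (32, 64))
y = F.linear(torch.randn(1, 64), w)  # shape (1, 32)
```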


## Generate Tokens in PyTorch Eager
```bash
python3 generate.py --prompt "Once upon a time" --gguf_file models/llama7b/ggml-model-Q4_0.gguf --tokenizer_path models/llama7b/tokenizer.model
```
190 changes: 190 additions & 0 deletions examples/models/llama2/experimental/generate.py
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
import argparse

from typing import Optional, Tuple

import torch

from executorch.examples.models.llama2.experimental.load_gguf_q4_0 import load_gguf_q4_0
from sentencepiece import SentencePieceProcessor


def multinomial_sample_one_no_sync(
probs_sort,
): # Does multinomial sampling without a cuda synchronization
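    # Exponential-race trick: dividing the probabilities by i.i.d. Exponential(1)
    # noise and taking the argmax draws from the same categorical distribution as
    # torch.multinomial, but avoids its host/device synchronization.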
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
logits = logits / max(temperature, 1e-5)

if top_k is not None:
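        # Top-k filtering: find the k-th largest logit (the "pivot") and mask
        # everything below it to -inf before the softmax.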
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v.select(-1, -1).unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[0, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs


def encode_tokens(tokenizer, string, bos=True, device="cpu"):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
return torch.tensor(tokens, dtype=torch.int, device=device)


def decode_one_token(
model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
logits = model(x)
return sample(logits, **sampling_kwargs)


def prefill(model: torch.nn.Module, x: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
return decode_one_token(model, x, **sampling_kwargs)[0]


def decode_n_tokens(
model: torch.nn.Module,
cur_token: torch.Tensor,
num_new_tokens: int,
callback=lambda _: _,
**sampling_kwargs,
):
print(f"cur_token: {cur_token}")
new_tokens, new_probs = [], []
for _ in range(num_new_tokens):
with torch.backends.cuda.sdp_kernel(
enable_flash=False, enable_mem_efficient=False, enable_math=True
): # Actually better for Inductor to codegen attention here
next_token, next_prob = decode_one_token(
model, cur_token.view(1, -1), **sampling_kwargs
)
new_tokens.append(next_token.clone())
# print(next_token)
callback(next_token)
new_probs.append(next_prob.clone())
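        # No KV cache is set up here, so the full sequence generated so far is
        # fed back through the model on every decoding step.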
cur_token = torch.cat((cur_token.squeeze(), next_token), dim=0)
# print(cur_token)

return new_tokens, new_probs


@torch.no_grad()
def generate(
model: torch.nn.Module,
prompt: torch.Tensor,
max_new_tokens: int,
*,
interactive: bool,
callback=lambda x: x,
**sampling_kwargs,
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""

    # figure out the expected final sequence length
T = prompt.size(0)
T_new = T + max_new_tokens
# if interactive:
# max_seq_length = 350
# else:
# max_seq_length = min(T_new, model.params.max_seq_len)

device, dtype = prompt.device, prompt.dtype

# with torch.device(device):
# model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
# input_pos = torch.arange(0, T, device=device)

next_token = prefill(model, prompt.view(1, -1), **sampling_kwargs)
seq[T] = next_token
callback(next_token)

cur_tokens = torch.cat((prompt, next_token), dim=0)
# input_pos = torch.tensor([T], device=device, dtype=torch.int)

generated_tokens, _ = decode_n_tokens(
model,
cur_tokens.view(1, -1),
# input_pos,
max_new_tokens - 1,
callback=callback,
**sampling_kwargs,
)
seq[T + 1 :] = torch.cat(generated_tokens)

return seq


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"--gguf_file",
type=str,
help="The GGUF file to load.",
)
parser.add_argument(
"--tokenizer_path",
type=str,
help="The tokenizer.model path.",
)
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
)

args = parser.parse_args()

tokenizer = SentencePieceProcessor(model_file=str(args.tokenizer_path))
encoded = encode_tokens(tokenizer, args.prompt, bos=True, device="cpu")

pt_model = load_gguf_q4_0(args.gguf_file)

max_new_tokens = 100
buffer = [tokenizer.decode(encoded.tolist())]
period_id = tokenizer.encode(".")[0]
done_generating = False

def callback(x):
nonlocal done_generating
if done_generating:
return
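        # Decode with a leading period so SentencePiece keeps the token's leading
        # space, then drop the period from the decoded string.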
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:])
if x.item() == tokenizer.eos_id():
done_generating = True
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()

generate(
pt_model,
encoded,
max_new_tokens,
interactive=False,
callback=callback,
temperature=1.0,
top_k=10,
)


if __name__ == "__main__":
main()
183 changes: 183 additions & 0 deletions examples/models/llama2/experimental/load_gguf_q4_0.py
@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Load llama model from a GGUF file, quantized in Q4_0 format.
# For float weights, we load them directly from the GGUF file.
# For Q4_0 weights, we load them into a Tensor subclass (GGMLInt4LinearWeight).
# This is done by replacing the linear weight with the subclass.

import logging
import os
from typing import Callable, Dict, Mapping

import torch
from executorch.examples.models.llama2.experimental.subclass import (
_unpack_two_uint8,
GGMLInt4LinearWeight,
to_float,
)
from executorch.extension.gguf_util.converters.llama_converter import (
_convert_gguf_tensor_name_to_llama_nn,
_create_pt_model,
)
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
from torchao.quantization.subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


def _replace_with_custom_fn_if_matches_filter(
# pyre-fixme[2]: Parameter must be annotated.
model,
replacement_fn,
filter_fn,
cur_fqn="",
) -> None:
"""
For each `child` in `model`, replaces it with `replacement_fn(child)`
if `filter_fn(child)` is `True`
"""
if filter_fn(model, cur_fqn[:-1]):
model = replacement_fn(model, cur_fqn[:-1])
return model
else:
for name, child in model.named_children():
new_child = _replace_with_custom_fn_if_matches_filter(
child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
)
if new_child is not child:
setattr(model, name, new_child)
return model


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_subclass_inserter(
weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], torch.nn.Module]:
def insert_subclass(lin, fqn):
# TODO: replace weights with gguf format tensor
# packed tensor should have size [numel / 32, 18]
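        # Q4_0 packs weights in blocks of 32: each 18-byte block is a 2-byte
        # fp16 scale followed by 16 bytes holding 32 4-bit quantized values.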
fqn = fqn + ".weight"
assert (
fqn in weight_map
), f"Expect {fqn} to be in weight map but not found. All keys are {weight_map.keys()}"
tensor = weight_map[fqn]
print(fqn, tensor.shape, tensor.data.shape, lin.weight.shape)
packed = torch.from_numpy(tensor.data).reshape(-1, 18)
scale = torch.tensor(_unpack_two_uint8(packed[:, :2]), dtype=torch.float16)
lin.weight = torch.nn.Parameter(
GGMLInt4LinearWeight(packed, scale, lin.weight.shape)
)
return lin

return insert_subclass


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_filter_fn(
weight_map: Dict[str, ReaderTensor]
) -> Callable[[torch.nn.Module, str], bool]:
def _is_linear(mod, fqn):
return (
isinstance(mod, torch.nn.Linear)
and hasattr(mod, "weight")
and weight_map[fqn + ".weight"].tensor_type == GGMLQuantizationType.Q4_0
and not isinstance(mod.weight, QuantizedLinearWeightBase)
)

return _is_linear


def change_linear_weights_to_q4_0_tensors(
model: torch.nn.Module, gguf_weights: GGUFWeights
) -> None:
"""
    Converts all Q4_0 linear weight tensors in `model` to the
    `GGMLInt4LinearWeight` tensor subclass, swapping in the packed
    GGUF data without modifying the linear modules themselves.
"""
assert gguf_weights is not None, "Must provide gguf_weights"
weight_map = {
_convert_gguf_tensor_name_to_llama_nn(tensor.name): tensor
for tensor in gguf_weights.tensors
}

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(weight_map),
_get_filter_fn(weight_map),
)


def get_float_weights(
pt_model: torch.nn.Module, gguf_weights: GGUFWeights
) -> Mapping[str, torch.Tensor]:
"""
    Returns a mapping from fqn to float weight tensor. Even though the model
    is quantized in Q4_0, some of its tensors are still stored as float; the
    Q4_0-quantized `token_embd.weight` is also dequantized into float here.
    Args:
        pt_model (torch.nn.Module): The model to load the weights into.
        gguf_weights (GGUFWeights): The GGUF weights to extract from.
"""
state_dict = {}
for tensor in gguf_weights.tensors:
model_key = _convert_gguf_tensor_name_to_llama_nn(tensor.name)
if (
tensor.tensor_type == GGMLQuantizationType.F32
or tensor.tensor_type == GGMLQuantizationType.F16
):
print(tensor.name)
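            # GGUF lists dimensions fastest-varying first, the reverse of the
            # PyTorch convention, so flip the shape before building the tensor.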
reversed_shape = tensor.shape[::-1]
new_tensor = tensor.data.reshape(reversed_shape)
state_dict[model_key] = torch.from_numpy(new_tensor)
        # token_embd.weight is stored in Q4_0; dequantize it into float here.
elif tensor.tensor_type == GGMLQuantizationType.Q4_0:
if tensor.name == "token_embd.weight":
print(tensor.name)
unpacked = to_float(torch.from_numpy(tensor.data.reshape(-1, 18)))
state_dict[model_key] = unpacked.reshape(
pt_model.params.vocab_size, pt_model.params.dim
)

    # Fake-initialize the causal attention mask to match llama_transformer.py.
for id in range(pt_model.params.n_layers):
mask_name = f"layers.{id}.attention.mask"
mask = torch.full(
(1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
float("-inf"),
)
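        # torch.triu with diagonal=1 keeps -inf strictly above the diagonal, so each
        # position can only attend to itself and earlier positions (causal masking).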
mask = torch.triu(mask, diagonal=1)
state_dict[mask_name] = mask
return state_dict


def load_gguf_q4_0(gguf_file: str) -> torch.nn.Module:
assert os.path.isfile(gguf_file), f"Expect a valid gguf_file path, got {gguf_file}"

logging.info(f"Loading GGUF file: {gguf_file}")
gguf_model_args, gguf_weights = load_file(gguf_file)

logging.info("Creating the PyTorch model")
pt_model = _create_pt_model(
gguf_model_args,
)

logging.info("Load float weights")
state_dict = get_float_weights(pt_model, gguf_weights)
pt_model.load_state_dict(state_dict, strict=False)

logging.info("Change linear weights to Q4_0 tensors")
change_linear_weights_to_q4_0_tensors(pt_model, gguf_weights)

pt_model = pt_model.to(dtype=torch.float16)

return pt_model
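For reference, here is a minimal sketch of dequantizing a single Q4_0 block, assuming the standard GGML Q4_0 layout (each 18-byte block is a little-endian fp16 scale followed by 32 packed 4-bit values stored with an offset of 8). The helpers imported from `subclass.py` (`_unpack_two_uint8`, `to_float`) presumably implement the batched equivalent; this function is illustrative only.

```python
import torch


def dequantize_q4_0_block(block: torch.Tensor) -> torch.Tensor:
    """block: uint8 tensor of shape [18]; returns the 32 float16 weights."""
    scale = block[:2].view(torch.float16)   # first two bytes: fp16 scale
    qs = block[2:]                          # remaining 16 bytes: packed nibbles
    lo = (qs & 0x0F).to(torch.int8) - 8     # low nibbles -> elements 0..15
    hi = (qs >> 4).to(torch.int8) - 8       # high nibbles -> elements 16..31
    return torch.cat([lo, hi]).to(torch.float16) * scale
```

Applying it row by row to the `[numel // 32, 18]` packed tensor would recover the full float weight.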
