Skip to content

Commit

Permalink
wip: 1.625 bpw ternary packing scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
compilade committed Jun 19, 2024
1 parent 89c7e4c commit 143b4b6
Show file tree
Hide file tree
Showing 11 changed files with 495 additions and 49 deletions.
73 changes: 39 additions & 34 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,27 @@ def write_tensors(self):
))

if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3 and not any(
self.match_model_tensor_name(new_name, key, None)
for key in [
gguf.MODEL_TENSOR.TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
]
):
data = gguf.quantize_q1_3(data)
assert data.dtype == np.uint8
data_qtype = gguf.GGMLQuantizationType.Q1_3

elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data = gguf.quantize_bf16(data)
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16

elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
elif (
self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0
or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
and gguf.can_quantize_to_q8_0(data)
):
data = gguf.quantize_q8_0(data)
assert data.dtype == np.uint8
data_qtype = gguf.GGMLQuantizationType.Q8_0
Expand Down Expand Up @@ -1401,6 +1416,12 @@ def write_tensors(self):
class BitnetModel(Model):
model_arch = gguf.MODEL_ARCH.BITNET

def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, model_name: str | None):
if ftype == gguf.LlamaFileType.GUESSED:
ftype = gguf.LlamaFileType.MOSTLY_Q1_3

super().__init__(dir_model, ftype, fname_out, is_big_endian, use_temp_file, eager, model_name)

def set_vocab(self):
self._set_vocab_sentencepiece()

Expand All @@ -1420,40 +1441,24 @@ def weight_quant(self, weight):
return weight.type(dtype), scale.type(torch.float32)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# transform weight into 1/0/-1 (in fp32)
if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
"o_proj.weight")):
weight_torch, scale_torch = self.weight_quant(data_torch)

tensors: list[tuple[str, Tensor]] = []

if name.endswith("q_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
elif name.endswith("k_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
elif name.endswith("v_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
elif name.endswith("o_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
elif name.endswith("up_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
elif name.endswith("down_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
elif name.endswith("gate_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))

if len(tensors) == 0:
tensors.append((self.map_tensor_name(name), data_torch))
new_name = self.map_tensor_name(name)

return tensors
if any(self.match_model_tensor_name(new_name, key, bid) for key in [
gguf.MODEL_TENSOR.ATTN_Q,
gguf.MODEL_TENSOR.ATTN_K,
gguf.MODEL_TENSOR.ATTN_V,
gguf.MODEL_TENSOR.ATTN_OUT,
gguf.MODEL_TENSOR.FFN_UP,
gguf.MODEL_TENSOR.FFN_DOWN,
gguf.MODEL_TENSOR.FFN_GATE,
]):
# transform weight into 1/0/-1 (in fp32)
weight_torch, scale_torch = self.weight_quant(data_torch)
yield (new_name, weight_torch)
yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
else:
yield (new_name, data_torch)


@Model.register("GrokForCausalLM")
Expand Down
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet 1.58b" },
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
Expand Down
44 changes: 44 additions & 0 deletions ggml-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ typedef sycl::half2 ggml_half2;

#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP

// 1.625 bpw for BitNet 1.58b models
#define QK1_3 64
typedef struct {
uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256)
uint8_t qs[QK1_3/64]; // 4 elements per byte
} block_q1_3;
static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding");

#define QK2_2 32
typedef struct {
uint8_t qs[QK2_2 / 4]; // nibbles / quants
Expand Down Expand Up @@ -339,6 +347,7 @@ typedef struct {
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");

// 1.5625 bpw
typedef struct {
ggml_half d;
uint8_t qs[QK_K/8];
Expand Down Expand Up @@ -1095,6 +1104,41 @@ GGML_TABLE_BEGIN(uint32_t, q22_grid, 256)
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
GGML_TABLE_END()

GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256)
0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff,
0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001,
0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff,
0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01,
0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00,
0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101,
0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100,
0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001,
0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000,
0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01,
0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00,
0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff,
0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100,
0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff,
0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0000,
0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff,
0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01,
0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff,
0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101,
0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff,
0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001,
0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000,
0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01,
0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00,
0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101,
0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100,
0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff,
0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000,
0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff,
0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00,
0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff,
0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101,
GGML_TABLE_END()

#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
Expand Down
Loading

0 comments on commit 143b4b6

Please sign in to comment.