Add support for BitnetForCausalLM (new model / new datatype) #7931

Merged
merged 38 commits on Jun 23, 2024

Changes shown are from 21 commits.

Commits (38)
076b4a1
hf bitnet v1
Eddie-Wang1120 Jun 5, 2024
57dfc3b
hf bitnet e2e v2
Eddie-Wang1120 Jun 5, 2024
1f2e0ee
finish bitnet e2e
Eddie-Wang1120 Jun 6, 2024
5e59660
finish f16 hf bitnet e2e
Eddie-Wang1120 Jun 7, 2024
2a01a7c
remove unsed
Eddie-Wang1120 Jun 7, 2024
4e1ab50
finish bitnet i2 e2e
Eddie-Wang1120 Jun 8, 2024
ca09085
move i2s to quantize v1
Eddie-Wang1120 Jun 9, 2024
dbee0a8
move i2 to quantize
Jun 9, 2024
1c5a8b7
clean code
Jun 9, 2024
3a0f8b0
clean code 2
Jun 9, 2024
97d22be
fix codestyle
Eddie-Wang1120 Jun 9, 2024
344467f
fix code
Eddie-Wang1120 Jun 9, 2024
65ac3a3
fix
Eddie-Wang1120 Jun 9, 2024
abd798d
fix code
Eddie-Wang1120 Jun 10, 2024
841c903
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 10, 2024
c0fd4df
fix merge
Eddie-Wang1120 Jun 10, 2024
de1d507
remove unused
Eddie-Wang1120 Jun 11, 2024
2322e9d
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 11, 2024
c0cd08d
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 12, 2024
f395dd9
change table name
Eddie-Wang1120 Jun 12, 2024
5e5eee7
fix whitespace
Eddie-Wang1120 Jun 12, 2024
7a8961f
delete redundant
Eddie-Wang1120 Jun 14, 2024
95dced0
i2_s to absmax
Eddie-Wang1120 Jun 15, 2024
569a03e
finish i2_s/i8_s vec_dot x86 simd
Eddie-Wang1120 Jun 15, 2024
a03eff3
i2s->q22
Eddie-Wang1120 Jun 17, 2024
4edc958
fix code
Eddie-Wang1120 Jun 18, 2024
89c7e4c
remove block scale
Eddie-Wang1120 Jun 18, 2024
fcf2da4
add dequantize
Eddie-Wang1120 Jun 19, 2024
fa9a742
fix seq
Eddie-Wang1120 Jun 19, 2024
230396b
update avx2
Eddie-Wang1120 Jun 19, 2024
2b09768
remove q2_2
Eddie-Wang1120 Jun 20, 2024
a58cf0d
remove q22_grid
Eddie-Wang1120 Jun 20, 2024
abcdc50
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 20, 2024
c6ddfa7
fix whitespace
Eddie-Wang1120 Jun 20, 2024
55a57a5
reuse llm_build_kv
Eddie-Wang1120 Jun 21, 2024
0520d88
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 21, 2024
16f0c30
Merge branch 'ggerganov:master' into bitnet
Eddie-Wang1120 Jun 23, 2024
226c5ee
fix bo
Eddie-Wang1120 Jun 23, 2024
29 changes: 29 additions & 0 deletions convert-hf-to-gguf.py
@@ -1397,6 +1397,35 @@ def write_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("BitnetForCausalLM")
class BitnetModel(Model):
model_arch = gguf.MODEL_ARCH.BITNET

def set_vocab(self):
self._set_vocab_sentencepiece()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)

def weight_quant(self, weight):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# transform weight into 1/0/-1 (in fp32)
if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
"o_proj.weight")):
data_torch = data_torch + (self.weight_quant(data_torch) - data_torch).detach()

return [(self.map_tensor_name(name), data_torch)]


@Model.register("GrokForCausalLM")
class GrokModel(Model):
model_arch = gguf.MODEL_ARCH.GROK
1 change: 1 addition & 0 deletions examples/quantize/quantize.cpp
@@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
67 changes: 67 additions & 0 deletions ggml-common.h
@@ -1022,6 +1022,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()

GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256)
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00010100, 0x01010100, 0x00010100, 0xff010100,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00010001, 0x01010001, 0x00010001, 0xff010001,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
0x00000101, 0x01000101, 0x00000101, 0xff000101,
0x00010101, 0x01010101, 0x00010101, 0xff010101,
0x00000101, 0x01000101, 0x00000101, 0xff000101,
0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00010001, 0x01010001, 0x00010001, 0xff010001,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00010100, 0x01010100, 0x00010100, 0xff010100,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
GGML_TABLE_END()

#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
145 changes: 145 additions & 0 deletions ggml-quants.c
@@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) {
}
#endif //__loongarch_asx

void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
int8_t* dst = (int8_t*)y;
double min = 0.00001;
double max = min;
for (int i = 0; i < n; ++i) {
max = MAX(max, (double)fabs((double)x[i]));
}
float s = 127 / max;
act_scales[0] = s;
float temp;
for (int i = 0; i < n; ++i) {
temp = round((double)(x[i] * s));
if (temp > 127) temp = 127;
if (temp < -128) temp = -128;
dst[i] = (int8_t)(temp);
}
}

// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
@@ -3306,6 +3324,53 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}

size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
// 2 bits per weight
UNUSED(quant_weights);

size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);

int n = nrow * n_per_row;

// f32 -> q8
double i2_scale = 0;
for (int i=0; i<n; i++) {
if (fabs((double)(src[i])) > 1e-6) {
i2_scale = (double)src[i];
@compilade (Collaborator), Jun 14, 2024:

This is only taking the last non-zero value of the tensor as the scale, if I understand correctly?

The other quants use the absmax, so this looks a bit weird.

Does it work as expected? If so, how or why?

Should it be the absmean of the non-zero values instead?

@Eddie-Wang1120 (Contributor, Author):

Thanks for the reminder!
Actually, at this point the weight matrix only contains three distinct values: (1, 0, -1) * scale (for example, 0.66, 0, -0.66). So as long as we don't pick 0 as the scale, any non-zero value can serve as the scale: picking 0.66 maps (0.66, 0, -0.66) to (1, 0, -1), while picking -0.66 maps it to (-1, 0, 1).
I can add a break to the loop so it picks the first non-zero value and avoids useless iterations.

@compilade (Collaborator), Jun 15, 2024:

> I can add a break to the loop so it picks the first non-zero value and avoids useless iterations.

This would slightly complicate the eventual Numpy implementation in gguf-py/gguf/quants.py for convert-hf-to-gguf.py. Quantization doesn't need to be particularly fast (I think?), but it needs to be reproducible (and possibly sane even on non-bitnet models, since people will try to apply this to models for which it's not appropriate). If all absolute non-zero values are the same in bitnet models, picking the absmax might be fine then.

(It's dequantization that needs to be fast for good inference speed)

@Eddie-Wang1120 (Contributor, Author):

I got it. I will change it to absmax to make it more reproducible.

@compilade (Collaborator), Jun 16, 2024:

@Eddie-Wang1120, actually (I only noticed it now), in section 2 of the BitNet 1.58b paper, they specifically say they use absmean:

> To constrain the weights to -1, 0, or +1, we adopt an absmean quantization function. It first scales the weight matrix by its average absolute value, and then round each value to the nearest integer among {-1, 0, +1}

See https://arxiv.org/html/2402.17764v1#S2

But if it's applied twice (e.g. on pre-quantized weights), then maybe the mean shouldn't include the zero values. (absmax is still fine, but only when the weights are pre-quantized)

@Eddie-Wang1120 (Contributor, Author):

An interesting fact: if we don't pre-quantize the weights to {-1, 0, +1} * scale, the generated tokens will be wrong. That's why I put the absmean quantization (weight pre-quantization) in convert-hf-to-gguf.py; otherwise we'd get a meaningless fp32/fp16 GGUF model.
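For reference, the absmean pre-quantization described in the paper and applied by weight_quant() in convert-hf-to-gguf.py can be sketched in plain C as follows. This is an illustration only, not code from the PR; the function name absmean_quant is made up for the example.

```c
#include <math.h>
#include <stddef.h>

// Scale the tensor by its mean absolute value, round each element to {-1, 0, +1},
// then scale back so the result stays in float (mirroring weight_quant() above).
static void absmean_quant(const float * w, float * out, size_t n) {
    double sum_abs = 0.0;
    for (size_t i = 0; i < n; ++i) {
        sum_abs += fabs((double) w[i]);
    }
    const double mean_abs = sum_abs / (double) n;
    const double s = 1.0 / (mean_abs > 1e-5 ? mean_abs : 1e-5); // clamp like weight_quant()
    for (size_t i = 0; i < n; ++i) {
        double q = round((double) w[i] * s);
        if (q >  1.0) q =  1.0;
        if (q < -1.0) q = -1.0;
        out[i] = (float) (q / s);
    }
}
```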

}
}

uint8_t* q8 = (uint8_t*)dst;
for (int i=0; i<n; i++) {
if (fabs((double)(src[i])) < 1e-6) {
q8[i] = 0;
continue;
}
q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
}

// q8 -> 0, 1, 3
// | | |
// 0, 1,-1

uint8_t* i2_weight = (uint8_t*)dst;
for (int i=0; i<n; i++) {
int group_idx = i / 4;
int group_pos = i % 4;
uint8_t temp = (q8[i] << (6 - 2 * group_pos));
q8[i] = 0;
i2_weight[group_idx] |= temp;
}
@compilade (Collaborator), Jun 14, 2024:

I wonder, maybe this could be made even more compact?

Instead of fitting only 4 ternary values per byte, it would be possible to fit 5 of them (because 3⁵ = 243, which is smaller than 256).

To avoid using modulo when dequantizing, assuming multiplication by 3 is fast (it can be turned into an addition and a bit shift), maybe storing an inverted value would work.

Not sure what speed difference it would have compared to bit shifts and masks, though.

Here's an example program verifying that multiplication can be an alternative to modulo by 3:
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main() {
    char s1[6] = {0};
    char s2[6] = {0};
    for (uint8_t i = 0; i < 243; ++i) {
        uint8_t n = i;
        // extract with modulo
        for (int j = 5; j-- > 0;) {
            s1[j] = (n % 3) + '0';
            n /= 3;
        }
        // invert the value
        uint8_t q = (((uint16_t) i) * 256) / 243;
        if (q != 0) {
            // otherwise it's always one smaller than the original
            q += 1;
        }
        // extract with multiplication
        for (int j = 0; j < 5; ++j) {
            uint16_t m = q * 3;
            s2[j] = (m >> 8) + '0';
            q = m & 0xFF;
        }
        printf("%s, %s: %s\n", s1, s2, strcmp(s1, s2) == 0 ? "\033[1;32mPASS\033[0m" : "\033[1;31mFAIL\033[0m");
    }

    return 0;
}

Compile and run:

$ gcc ternary-packing.c -o ternary-packing
$ ./ternary-packing
Output:
$ ./ternary-packing
00000, 00000: PASS
00001, 00001: PASS
00002, 00002: PASS
00010, 00010: PASS
00011, 00011: PASS
00012, 00012: PASS
00020, 00020: PASS
00021, 00021: PASS
00022, 00022: PASS
00100, 00100: PASS
00101, 00101: PASS
00102, 00102: PASS
00110, 00110: PASS
00111, 00111: PASS
00112, 00112: PASS
00120, 00120: PASS
00121, 00121: PASS
00122, 00122: PASS
00200, 00200: PASS
00201, 00201: PASS
00202, 00202: PASS
00210, 00210: PASS
00211, 00211: PASS
00212, 00212: PASS
00220, 00220: PASS
00221, 00221: PASS
00222, 00222: PASS
01000, 01000: PASS
01001, 01001: PASS
01002, 01002: PASS
01010, 01010: PASS
01011, 01011: PASS
01012, 01012: PASS
01020, 01020: PASS
01021, 01021: PASS
01022, 01022: PASS
01100, 01100: PASS
01101, 01101: PASS
01102, 01102: PASS
01110, 01110: PASS
01111, 01111: PASS
01112, 01112: PASS
01120, 01120: PASS
01121, 01121: PASS
01122, 01122: PASS
01200, 01200: PASS
01201, 01201: PASS
01202, 01202: PASS
01210, 01210: PASS
01211, 01211: PASS
01212, 01212: PASS
01220, 01220: PASS
01221, 01221: PASS
01222, 01222: PASS
02000, 02000: PASS
02001, 02001: PASS
02002, 02002: PASS
02010, 02010: PASS
02011, 02011: PASS
02012, 02012: PASS
02020, 02020: PASS
02021, 02021: PASS
02022, 02022: PASS
02100, 02100: PASS
02101, 02101: PASS
02102, 02102: PASS
02110, 02110: PASS
02111, 02111: PASS
02112, 02112: PASS
02120, 02120: PASS
02121, 02121: PASS
02122, 02122: PASS
02200, 02200: PASS
02201, 02201: PASS
02202, 02202: PASS
02210, 02210: PASS
02211, 02211: PASS
02212, 02212: PASS
02220, 02220: PASS
02221, 02221: PASS
02222, 02222: PASS
10000, 10000: PASS
10001, 10001: PASS
10002, 10002: PASS
10010, 10010: PASS
10011, 10011: PASS
10012, 10012: PASS
10020, 10020: PASS
10021, 10021: PASS
10022, 10022: PASS
10100, 10100: PASS
10101, 10101: PASS
10102, 10102: PASS
10110, 10110: PASS
10111, 10111: PASS
10112, 10112: PASS
10120, 10120: PASS
10121, 10121: PASS
10122, 10122: PASS
10200, 10200: PASS
10201, 10201: PASS
10202, 10202: PASS
10210, 10210: PASS
10211, 10211: PASS
10212, 10212: PASS
10220, 10220: PASS
10221, 10221: PASS
10222, 10222: PASS
11000, 11000: PASS
11001, 11001: PASS
11002, 11002: PASS
11010, 11010: PASS
11011, 11011: PASS
11012, 11012: PASS
11020, 11020: PASS
11021, 11021: PASS
11022, 11022: PASS
11100, 11100: PASS
11101, 11101: PASS
11102, 11102: PASS
11110, 11110: PASS
11111, 11111: PASS
11112, 11112: PASS
11120, 11120: PASS
11121, 11121: PASS
11122, 11122: PASS
11200, 11200: PASS
11201, 11201: PASS
11202, 11202: PASS
11210, 11210: PASS
11211, 11211: PASS
11212, 11212: PASS
11220, 11220: PASS
11221, 11221: PASS
11222, 11222: PASS
12000, 12000: PASS
12001, 12001: PASS
12002, 12002: PASS
12010, 12010: PASS
12011, 12011: PASS
12012, 12012: PASS
12020, 12020: PASS
12021, 12021: PASS
12022, 12022: PASS
12100, 12100: PASS
12101, 12101: PASS
12102, 12102: PASS
12110, 12110: PASS
12111, 12111: PASS
12112, 12112: PASS
12120, 12120: PASS
12121, 12121: PASS
12122, 12122: PASS
12200, 12200: PASS
12201, 12201: PASS
12202, 12202: PASS
12210, 12210: PASS
12211, 12211: PASS
12212, 12212: PASS
12220, 12220: PASS
12221, 12221: PASS
12222, 12222: PASS
20000, 20000: PASS
20001, 20001: PASS
20002, 20002: PASS
20010, 20010: PASS
20011, 20011: PASS
20012, 20012: PASS
20020, 20020: PASS
20021, 20021: PASS
20022, 20022: PASS
20100, 20100: PASS
20101, 20101: PASS
20102, 20102: PASS
20110, 20110: PASS
20111, 20111: PASS
20112, 20112: PASS
20120, 20120: PASS
20121, 20121: PASS
20122, 20122: PASS
20200, 20200: PASS
20201, 20201: PASS
20202, 20202: PASS
20210, 20210: PASS
20211, 20211: PASS
20212, 20212: PASS
20220, 20220: PASS
20221, 20221: PASS
20222, 20222: PASS
21000, 21000: PASS
21001, 21001: PASS
21002, 21002: PASS
21010, 21010: PASS
21011, 21011: PASS
21012, 21012: PASS
21020, 21020: PASS
21021, 21021: PASS
21022, 21022: PASS
21100, 21100: PASS
21101, 21101: PASS
21102, 21102: PASS
21110, 21110: PASS
21111, 21111: PASS
21112, 21112: PASS
21120, 21120: PASS
21121, 21121: PASS
21122, 21122: PASS
21200, 21200: PASS
21201, 21201: PASS
21202, 21202: PASS
21210, 21210: PASS
21211, 21211: PASS
21212, 21212: PASS
21220, 21220: PASS
21221, 21221: PASS
21222, 21222: PASS
22000, 22000: PASS
22001, 22001: PASS
22002, 22002: PASS
22010, 22010: PASS
22011, 22011: PASS
22012, 22012: PASS
22020, 22020: PASS
22021, 22021: PASS
22022, 22022: PASS
22100, 22100: PASS
22101, 22101: PASS
22102, 22102: PASS
22110, 22110: PASS
22111, 22111: PASS
22112, 22112: PASS
22120, 22120: PASS
22121, 22121: PASS
22122, 22122: PASS
22200, 22200: PASS
22201, 22201: PASS
22202, 22202: PASS
22210, 22210: PASS
22211, 22211: PASS
22212, 22212: PASS
22220, 22220: PASS
22221, 22221: PASS
22222, 22222: PASS

@Eddie-Wang1120 (Contributor, Author):

That is a very good thought!
The weights can certainly be packed even more compactly: as you said, each weight is really a 3-valued symbol, so five of them fit in a byte. What makes it less attractive is that it seems hard to find a suitable SIMD solution for that packing. With 2-bit packing we can use _mm256_sign_epi16, but the base-3 packing doesn't seem to work well with those strict alignment requirements.
So I chose the 2-bit packing for now, while still looking for a more efficient solution.
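For context, the 2-bit packing used here (which the i2s_i8s lookup table above expands four values at a time) can be sketched as a standalone C example; this is an illustration, not code from the PR:

```c
#include <stdint.h>
#include <stdio.h>

// Pack four ternary weights into one byte using the codes 0 -> 0, +1 -> 1, -1 -> 3,
// with element 0 in the top bit pair (the same shift, 6 - 2*pos, as quantize_i2_s).
static uint8_t pack4(const int8_t w[4]) {
    uint8_t byte = 0;
    for (int pos = 0; pos < 4; ++pos) {
        const uint8_t code = (w[pos] > 0) ? 1 : (w[pos] < 0) ? 3 : 0;
        byte |= (uint8_t) (code << (6 - 2 * pos));
    }
    return byte;
}

// Unpack one byte back into four weights in {-1, 0, +1}; this is the mapping the
// i2s_i8s table provides in a single lookup.
static void unpack4(uint8_t byte, int8_t w[4]) {
    for (int pos = 0; pos < 4; ++pos) {
        const uint8_t code = (byte >> (6 - 2 * pos)) & 3;
        w[pos] = (code == 1) ? 1 : (code == 3) ? -1 : 0;
    }
}

int main(void) {
    const int8_t w[4] = { 1, 0, -1, 1 };
    int8_t out[4];
    unpack4(pack4(w), out);
    printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // prints: 1 0 -1 1
    return 0;
}
```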


float* scale_ptr = (float*)((char*)i2_weight + n / 4);
for (int i=0; i<8; i++) {
scale_ptr[i] = i2_scale;
Collaborator:

Why is the same scale stored 8 times?

@Eddie-Wang1120 (Contributor, Author), Jun 15, 2024:

I noticed that there is a 32-byte alignment restriction in GGUF, so I stored the float32 scale 8 times to fill the alignment. It would still work if I changed it to scale_ptr[0] = i2_scale.
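In other words, the data produced by quantize_i2_s is n/4 bytes of packed 2-bit codes followed by a 32-byte trailer holding eight identical copies of the f32 tensor scale. A small sketch of reading it back (an illustration, not PR code; i2_s_read_scale is a made-up helper name):

```c
#include <stdint.h>
#include <string.h>

// Layout written by quantize_i2_s above:
//   [ n/4 bytes of 2-bit codes ][ 32-byte trailer = 8 copies of the float scale ]
static float i2_s_read_scale(const uint8_t * data, int64_t n_elements) {
    float scale;
    // the first of the eight identical copies sits right after the packed weights
    memcpy(&scale, data + n_elements / 4, sizeof(float));
    return scale;
}
```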

}

// 32B for scale
return nrow * row_size / 4 + 32;
@compilade (Collaborator), Jun 14, 2024:

Regarding the tensor-wide scales, even though the paper suggests using them, I wonder if using block scales would work too, so that it works better with the existing ggml type infrastructure. The scales could even all be made equal if that's a problem.

When (or if) packing 5 values per byte, since the row sizes from the bitnet models are usually not multiples of 5 (e.g. 1536, 2048), and since the 3B model uses a hidden_size of 3200 which isn't a multiple of 256, using blocks of 128 elements could work. Two groups of 64 elements, with each group having 12 bytes with 5 elements per byte, with 1 more byte with 4 elements, so 2*(12+1) = 26 bytes, and then a scale. If the scale is in f32, that would make this 1.875 bits per weight, while an f16 or bf16 scale would make this 1.75 bits per weight. (no scale would be 1.625 bpw, which is very close to the ideal of 1.5849625 bpw)

If packing only 4 ternary values per byte (as in i2_s), then using blocks of 128 elements with an f32 scale would make this 2.25 bits per weight, while using a scale in f16 or bf16 would make this 2.125 bits per weight, and would pretty much be like Q2_0 if that existed.

@Eddie-Wang1120 (Contributor, Author):

Well, personally, I still think the tensor-wide scale is the better approach for bitnet. At least for the 2-bit packing, 2.25 bpw means around a 10% increase in model size, which is somewhat beyond what is acceptable.

@compilade (Collaborator), Jun 16, 2024:

Since 2 bits is already wasting 20% of the tensor size compared to the 1.6 bpw ideal for ternary, maybe there could be a way to still make this a block-wise quant (e.g. 4 elements per block of 1 byte), and have a row-wise/tensor-wise scale by somehow encoding it in unused bits in the first few blocks of each row? Might be a bad idea though, but I don't know yet why (maybe the overhead in recovering the scale?). (This would also require asserting a minimal row size in the quantize function.)

Because 4 ternary values fit in 7 bits (3^4 == 81 < 128), and you're already using a lookup table to expand the packed bits into 4 bytes, this could let the SIMD vec_dot stay pretty much identical to how it is now, except maybe it could include the scaling in its result?

Not sure yet how to pack the scale in i8_s, though, or some other way to let ggml_vec_dot_i2_i8_s have access to its scale.

Anyway, at least this gives some ideas to try eventually.
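One possible reading of this top-bit idea, as a standalone sketch (purely hypothetical, not something this PR implements; embed_scale and extract_scale are made-up names): with only 7 bits needed per 4 ternary values, the top bit of each of the first 32 bytes of a row could carry one bit of the f32 row scale.

```c
#include <stdint.h>
#include <string.h>

// Hide one bit of the 32-bit float scale in the top bit of each of the first
// 32 bytes of a row; the lower 7 bits of every byte still hold 4 ternary
// values (3^4 == 81 < 128). Requires rows of at least 32 bytes (128 elements).
static void embed_scale(uint8_t * row, float scale) {
    uint32_t bits;
    memcpy(&bits, &scale, sizeof(bits));
    for (int i = 0; i < 32; ++i) {
        row[i] = (uint8_t) ((row[i] & 0x7F) | (((bits >> i) & 1u) << 7));
    }
}

static float extract_scale(const uint8_t * row) {
    uint32_t bits = 0;
    for (int i = 0; i < 32; ++i) {
        bits |= (uint32_t) (row[i] >> 7) << i;
    }
    float scale;
    memcpy(&scale, &bits, sizeof(scale));
    return scale;
}
```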

@Eddie-Wang1120 (Contributor, Author), Jun 16, 2024:

Got your idea. Assuming we have 128 contiguous values, we can pack them into (128 / 4) * 7 = 224 bits, and the remaining 32 bits hold a float32 scale, so it's still 2 bpw. The block size could be 28 chars (224 bits) + 1 float32 (32 bits) == 32 bytes. One thing that worries me a little is that we need some extra shifting to align the weights before indexing into the lookup table, which may slow down the kernel, but it deserves a try.

@compilade (Collaborator), Jun 16, 2024:

> The block size could be 28 chars (224 bits) + 1 float32 (32 bits) == 32 bytes. One thing that worries me a little is that we need some extra shifting to align the weights before indexing into the lookup table, which may slow down the kernel

Note that block sizes are measured in elements, while type sizes are in bytes.

To keep the alignment, my suggestion was actually to keep using 8 bits per 4 elements (so that alignment remains easy), but also use the top bit of the first 16 or 32 bytes to store the scale. Only the lower (or upper? doesn't matter) 7 bits of the bytes would store 4 elements, using the fact that 3^4 == 81 < 128 == 2^7.

To go for maximum compactness, the same idea can be applied to 5 elements per byte to achieve 1.625 bpw. The type size would be 13 bytes, the block size 64 elements: 12 bytes of 5 elements per byte and 1 byte of 4 elements, plus part of the scale. But this is more complicated, because 5 isn't a power of 2, so the SIMD vec_dot would need lots of non-trivial modifications, unlike with the other 2 bpw suggestion.
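To make the bits-per-weight figures in this thread concrete, here is a small standalone calculation (an illustration only, not part of the PR):

```c
#include <stdio.h>

// bpw = 8 * bytes_per_block / elements_per_block
static double bpw(int block_bytes, int block_elements) {
    return 8.0 * block_bytes / block_elements;
}

int main(void) {
    printf("4 per byte, 128 elems, f32 scale: %.3f bpw\n", bpw(128 / 4 + 4, 128));      // 2.250
    printf("4 per byte, 128 elems, f16 scale: %.3f bpw\n", bpw(128 / 4 + 2, 128));      // 2.125
    printf("5 per byte, 128 elems, f32 scale: %.3f bpw\n", bpw(2 * (12 + 1) + 4, 128)); // 1.875
    printf("5 per byte, 128 elems, f16 scale: %.3f bpw\n", bpw(2 * (12 + 1) + 2, 128)); // 1.750
    printf("5 per byte,  64 elems, no scale : %.3f bpw\n", bpw(12 + 1, 64));            // 1.625
    return 0;
}
```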

}

// ====================== "True" 2-bit (de)-quantization

void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
@@ -3726,6 +3791,85 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif

//====================================== I2 ===============================================

void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const uint8_t * restrict x = vx;
const int8_t * restrict y = vy;

UNUSED(bs);
UNUSED(bx);
UNUSED(by);
UNUSED(nrc);

// TODO
// #if defined(__AVX2__)
// __m256i accu = _mm256_setzero_si256();

// for (int i=0; i<n/32; i++) {
// const int8_t* w0 = (const int8_t *)(i2s_i8s + x[i*8 + 0]);
// const int8_t* w1 = (const int8_t *)(i2s_i8s + x[i*8 + 1]);
// const int8_t* w2 = (const int8_t *)(i2s_i8s + x[i*8 + 2]);
// const int8_t* w3 = (const int8_t *)(i2s_i8s + x[i*8 + 3]);
// const int8_t* w4 = (const int8_t *)(i2s_i8s + x[i*8 + 4]);
// const int8_t* w5 = (const int8_t *)(i2s_i8s + x[i*8 + 5]);
// const int8_t* w6 = (const int8_t *)(i2s_i8s + x[i*8 + 6]);
// const int8_t* w7 = (const int8_t *)(i2s_i8s + x[i*8 + 7]);

// __m256i xq8 = _mm256_set_epi8(
// w0[0], w0[1], w0[2], w0[3],
// w1[0], w1[1], w1[2], w1[3],
// w2[0], w2[1], w2[2], w2[3],
// w3[0], w3[1], w3[2], w3[3],
// w4[0], w4[1], w4[2], w4[3],
// w5[0], w5[1], w5[2], w5[3],
// w6[0], w6[1], w6[2], w6[3],
// w7[0], w7[1], w7[2], w7[3]
// );

// __m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i*32));

// __m128i hxq8 = _mm256_castsi256_si128(xq8);
// __m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
// __m128i hyq8 = _mm256_castsi256_si128(yq8);
// __m128i lyq8 = _mm256_extractf128_si256(yq8, 1);

// __m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
// __m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
// __m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
// __m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);

// __m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
// __m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);

// __m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(hzq16));
// __m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(hzq16, 1));
// __m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(lzq16));
// __m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(lzq16, 1));

// accu = _mm256_add_epi32(accu, hhzq32);
// accu = _mm256_add_epi32(accu, hlzq32);
// accu = _mm256_add_epi32(accu, llzq32);
// accu = _mm256_add_epi32(accu, lhzq32);
// }

// int sumi = hsum_i32_8(accu);
// *s = (float)sumi;
// #else

int sumi = 0;

for (int i = 0; i < n / 4; i++) {
const int8_t* weight = (const int8_t *)(i2s_i8s + x[i]);
sumi += (int)y[i*4+0] * weight[0];
sumi += (int)y[i*4+1] * weight[1];
sumi += (int)y[i*4+2] * weight[2];
sumi += (int)y[i*4+3] * weight[3];
}
*s = (float)sumi;
// #endif
}

void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -14367,6 +14511,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_I2_S:
// nothing to validate
break;
default:
3 changes: 3 additions & 0 deletions ggml-quants.h
@@ -51,6 +51,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_i8_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n);

// Dequantization
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -99,6 +100,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_i2_i8_s (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -121,6 +123,7 @@ size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_i2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);