Skip to content

Commit

Permalink
ggml : fix quant dot product with odd number of blocks (#8549)
Browse files Browse the repository at this point in the history
* ggml : fix iq4_nl dot product with odd number of blocks

* ggml : fix odd blocks for ARM_NEON (#8556)

* ggml : fix iq4_nl dot product with odd number of blocks

* ggml : fix q4_1

* ggml : fix q5_0

* ggml : fix q5_1

* ggml : fix iq4_nl metal

ggml-ci

* ggml : fix q4_0

* ggml : fix q8_0

ggml-ci

* ggml : remove special Q4_0 code for first 2 blocks

* ggml : fix sumf redefinition

---------

Co-authored-by: slaren <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
slaren and ggerganov authored Jul 19, 2024
1 parent 57b1d4f commit 87e397d
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 501 deletions.
4 changes: 0 additions & 4 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1786,10 +1786,6 @@ static enum ggml_status ggml_metal_graph_compute(
}
};

if (ggml_is_quantized(src0t)) {
GGML_ASSERT(ne00 >= nth0*nth1);
}

[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -4757,7 +4757,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];

for (int row = 0; row < 2; ++row) {
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {

device const block_iq4_nl & xb = x[row*nb + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
Expand Down Expand Up @@ -4789,7 +4789,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
yb += 16 * QK4_NL;
}

for (int row = 0; row < 2; ++row) {
for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
Expand Down
Loading

0 comments on commit 87e397d

Please sign in to comment.