llamafile : improve sgemm.cpp (ggerganov#6796)
* llamafile : improve sgemm.cpp

- Re-enable by default
- Fix issue described in ggerganov#6716
- Make code more abstract, elegant, and maintainable
- Faster handling of weirdly shaped `m` and `n` edge cases (a sketch of the idea appears after the sgemm.cpp entry below)

* Address review comments

* Help clang produce fma instructions (see the sketch after this list)

* Address review comments
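
On the fma point: clang does not always contract a separate multiply and add into a single fma instruction (it depends on the fp-contraction settings in effect), so hot floating-point kernels often spell the fusion out explicitly. A minimal sketch of the pattern, assuming an x86 AVX build; `madd256` is an illustrative helper name, not the actual sgemm.cpp code:

```c
#include <immintrin.h>

// Spell the fused multiply-add explicitly so codegen does not depend on
// the compiler's fp-contraction settings.
static inline __m256 madd256(__m256 a, __m256 b, __m256 c) {
#if defined(__FMA__)
    return _mm256_fmadd_ps(a, b, c);              // one vfmadd instruction
#else
    return _mm256_add_ps(_mm256_mul_ps(a, b), c); // fallback: mul then add
#endif
}
```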
jart authored Apr 22, 2024 · 1 parent e931888 · commit 192090b
Showing 4 changed files with 412 additions and 573 deletions.
CMakeLists.txt — 5 additions, 11 deletions

```diff
@@ -43,17 +43,11 @@ else()
     set(LLAMA_METAL_DEFAULT OFF)
 endif()
 
-# TODO: fix this for Android CI
-#       https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191
-#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
-#    set(LLAMA_LLAMAFILE_DEFAULT OFF)
-#else()
-#    set(LLAMA_LLAMAFILE_DEFAULT ON)
-#endif()
-
-# TODO: temporary disable until MoE is fixed
-#       https://github.com/ggerganov/llama.cpp/pull/6716
-set(LLAMA_LLAMAFILE_DEFAULT OFF)
+if (CMAKE_SYSTEM_NAME MATCHES "ANDROID")
+    set(LLAMA_LLAMAFILE_DEFAULT OFF)
+else()
+    set(LLAMA_LLAMAFILE_DEFAULT ON)
+endif()
 
 # general
 option(BUILD_SHARED_LIBS "build shared libraries" OFF)
```
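With the two TODO blocks gone, `LLAMA_LLAMAFILE_DEFAULT` is ON everywhere except Android, where the CI issue linked above remains. This only sets the default: assuming it still feeds an `option(LLAMA_LLAMAFILE ...)` declaration elsewhere in CMakeLists.txt (not shown in this hunk), passing `-DLLAMA_LLAMAFILE=OFF` on the cmake command line should still opt out.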
Makefile — 0 additions, 4 deletions

```diff
@@ -384,10 +384,6 @@ ifdef LLAMA_OPENBLAS
 	MK_LDFLAGS  += $(shell pkg-config --libs openblas)
 endif # LLAMA_OPENBLAS
 
-# TODO: temporary disable until MoE is fixed
-#       https://github.com/ggerganov/llama.cpp/pull/6716
-LLAMA_NO_LLAMAFILE := 1
-
 ifndef LLAMA_NO_LLAMAFILE
 	MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
 	OBJS        += sgemm.o
```
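Removing the `LLAMA_NO_LLAMAFILE := 1` override makes the `ifndef LLAMA_NO_LLAMAFILE` branch reachable again, so Make builds once more compile `sgemm.o` and define `GGML_USE_LLAMAFILE` by default; invoking `make LLAMA_NO_LLAMAFILE=1` opts back out.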
ggml.c — 3 additions, 5 deletions

```diff
@@ -10825,7 +10825,7 @@ static void ggml_compute_forward_mul_mat(
 #endif
 
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type)) {
+    if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
```
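Why the condition changed (a reading of the hunk against ggerganov#6716, not stated in the commit): `nb10 == ggml_type_size(src1->type)` only checks that elements within a row are dense, which a permuted or sliced view can satisfy while its higher strides still leave gaps, whereas `llamafile_sgemm` assumes a fully packed operand. `src1_cont` is computed earlier in `ggml_compute_forward_mul_mat` from `ggml_is_contiguous(src1)`. A sketch of the distinction, assuming ggml.h's `ne`/`nb` tensor fields:

```c
#include <stdbool.h>
#include "ggml.h"

// Old test: dense elements within each row. A transposed view passes
// this even though its rows/planes/batches are not packed back to back.
static bool rows_dense(const struct ggml_tensor * t) {
    return t->nb[0] == ggml_type_size(t->type);
}

// New test (what src1_cont requires): packed across all four dimensions,
// mirroring ggml_is_contiguous.
static bool fully_packed(const struct ggml_tensor * t) {
    return t->nb[0] == ggml_type_size(t->type)
        && t->nb[1] == t->nb[0]*t->ne[0]/ggml_blck_size(t->type)
        && t->nb[2] == t->nb[1]*t->ne[1]
        && t->nb[3] == t->nb[2]*t->ne[2];
}
```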
```diff
@@ -10878,15 +10878,13 @@ UseGgmlGemm1:;
     const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
 #if GGML_USE_LLAMAFILE
-    if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) {
+    if (src1->type != vec_dot_type) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
                 if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
-                                     (const char *)wdata + ggml_row_size(vec_dot_type,
-                                         nb12/ggml_type_size(src1->type)*i12 +
-                                         nb13/ggml_type_size(src1->type)*i13),
+                                     (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
```
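The second fix in this hunk concerns the scratch buffer: `wdata` holds src1 converted to `vec_dot_type`, packed as `ne11` rows per `i12` plane and `ne12*ne11` rows per `i13` batch, each exactly `row_size` bytes. Offsets into it are therefore a function of logical indices only; the old code derived them from `nb12`/`nb13`, the byte strides of the original src1 layout rather than of the repacked buffer. A sketch of the new arithmetic (helper name hypothetical):

```c
#include <stddef.h>
#include <stdint.h>

// Row-indexed offset into the repacked scratch buffer: plane i12 starts
// ne11 rows in, batch i13 starts ne12*ne11 rows in, each row_size bytes.
static size_t wdata_offset(int64_t i12, int64_t i13,
                           int64_t ne11, int64_t ne12, size_t row_size) {
    return (size_t)(i12*ne11 + i13*ne12*ne11) * row_size;
}
```

For example, `i12 = 1, i13 = 0` lands exactly one plane (`ne11` rows) past the start of `wdata`, regardless of how src1 was strided.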
sgemm.cpp — diff not loaded on this page; by subtraction from the totals above, it accounts for the remaining 404 additions and 553 deletions.
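Since the sgemm.cpp diff is missing here, the commit message's faster "weirdly shaped `m` and `n` edge cases" can only be illustrated, not quoted. The usual trick, sketched below under that assumption (names hypothetical; the operand layout roughly follows the `llamafile_sgemm` call shown above, with both inputs stored row-wise along the shared k dimension), is to cover the ragged right and bottom borders of the output with small shape-generic tiles instead of abandoning the fast path for the whole matrix:

```c
#include <stddef.h>

// Shape-generic micro-kernel for an rm×rn block of C; a real kernel set
// would pair this with vectorized fixed-size tiles for the interior.
static void gemm_tile(const float *A, size_t lda, const float *B, size_t ldb,
                      float *C, size_t ldc, size_t k,
                      size_t i0, size_t j0, size_t rm, size_t rn) {
    for (size_t i = 0; i < rm; ++i)
        for (size_t j = 0; j < rn; ++j) {
            float sum = 0.0f;
            for (size_t l = 0; l < k; ++l)
                sum += A[(i0 + i)*lda + l] * B[(j0 + j)*ldb + l];
            C[(j0 + j)*ldc + (i0 + i)] = sum;
        }
}

// Tile an arbitrary m×n output with 4×4 blocks plus ragged edge blocks,
// so odd shapes cost a few small tiles rather than a scalar fallback.
static void gemm(const float *A, size_t lda, const float *B, size_t ldb,
                 float *C, size_t ldc, size_t m, size_t n, size_t k) {
    for (size_t i = 0; i < m; i += 4)
        for (size_t j = 0; j < n; j += 4)
            gemm_tile(A, lda, B, ldb, C, ldc, k, i, j,
                      m - i < 4 ? m - i : 4,
                      n - j < 4 ? n - j : 4);
}
```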
