Skip to content

Commit

Permalink
fix(ios): add missing ggml-metal-impl.h
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 21, 2024
1 parent 9ccaad9 commit 96b6dd7
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 1 deletion.
249 changes: 249 additions & 0 deletions cpp/ggml-metal-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
#ifndef LM_GGML_METAL_IMPL
#define LM_GGML_METAL_IMPL

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
// however, be careful from int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t dim;
} lm_ggml_metal_kargs_concat;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
} lm_ggml_metal_kargs_bin;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} lm_ggml_metal_kargs_repeat;

typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} lm_ggml_metal_kargs_cpy;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t n_past;
int32_t n_dims;
int32_t n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
} lm_ggml_metal_kargs_rope;

typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} lm_ggml_metal_kargs_flash_attn_ext;

typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} lm_ggml_metal_kargs_mul_mm;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} lm_ggml_metal_kargs_mul_mv;

typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
} lm_ggml_metal_kargs_mul_mm_id;

typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
uint64_t nb1;
} lm_ggml_metal_kargs_mul_mv_id;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} lm_ggml_metal_kargs_norm;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} lm_ggml_metal_kargs_rms_norm;

#endif // LM_GGML_METAL_IMPL
2 changes: 1 addition & 1 deletion example/ios/.xcode.env.local
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732158453708-0.8549147503631913/node
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732160897686-0.7191373753799657/node
2 changes: 2 additions & 0 deletions scripts/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cp ./llama.cpp/ggml/include/ggml-opt.h ./cpp/ggml-opt.h
cp ./llama.cpp/ggml/include/ggml-metal.h ./cpp/ggml-metal.h

cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal.m ./cpp/ggml-metal.m
cp ./llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h ./cpp/ggml-metal-impl.h

cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c ./cpp/ggml-cpu.c
cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp ./cpp/ggml-cpu.cpp
Expand Down Expand Up @@ -90,6 +91,7 @@ files_add_lm_prefix=(
"./cpp/ggml-opt.cpp"
"./cpp/ggml-metal.h"
"./cpp/ggml-metal.m"
"./cpp/ggml-metal-impl.h"
"./cpp/ggml-quants.h"
"./cpp/ggml-quants.c"
"./cpp/ggml-alloc.h"
Expand Down

0 comments on commit 96b6dd7

Please sign in to comment.