Skip to content

Commit

Permalink
fix(ios): add missing ggml-metal-impl.h
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 21, 2024
1 parent 053ef40 commit 9c8dc01
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 0 deletions.
249 changes: 249 additions & 0 deletions cpp/ggml-metal-impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
#ifndef WSP_GGML_METAL_IMPL
#define WSP_GGML_METAL_IMPL

// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
// however, be careful from int overflows when using those in the kernel implementation
//
// - strides (e.g. nb00) use uint64_t

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t dim;
} wsp_ggml_metal_kargs_concat;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t offs;
} wsp_ggml_metal_kargs_bin;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} wsp_ggml_metal_kargs_repeat;

typedef struct {
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int64_t ne0;
int64_t ne1;
int64_t ne2;
int64_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} wsp_ggml_metal_kargs_cpy;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
int32_t n_past;
int32_t n_dims;
int32_t n_ctx_orig;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
} wsp_ggml_metal_kargs_rope;

typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
uint64_t nb_12_1;
uint64_t nb_12_2;
uint64_t nb_12_3;
uint64_t nb31;
int32_t ne1;
int32_t ne2;
float scale;
float max_bias;
float m0;
float m1;
uint16_t n_head_log2;
float logit_softcap;
} wsp_ggml_metal_kargs_flash_attn_ext;

typedef struct {
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} wsp_ggml_metal_kargs_mul_mm;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
} wsp_ggml_metal_kargs_mul_mv;

typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne02;
uint64_t nb01;
uint64_t nb02;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
} wsp_ggml_metal_kargs_mul_mm_id;

typedef struct {
int32_t nei0;
int32_t nei1;
uint64_t nbi1;
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
int32_t ne0;
int32_t ne1;
uint64_t nb1;
} wsp_ggml_metal_kargs_mul_mv_id;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} wsp_ggml_metal_kargs_norm;

typedef struct {
int32_t ne00;
int32_t ne00_4;
uint64_t nb01;
float eps;
} wsp_ggml_metal_kargs_rms_norm;

#endif // WSP_GGML_METAL_IMPL
2 changes: 2 additions & 0 deletions scripts/bootstrap.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cp ./whisper.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h ./cpp/ggml-cpu-quants.h
cp ./whisper.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c ./cpp/ggml-cpu-quants.c

cp ./whisper.cpp/ggml/src/ggml-metal/ggml-metal.m ./cpp/ggml-metal.m
cp ./whisper.cpp/ggml/src/ggml-metal/ggml-metal-impl.h ./cpp/ggml-metal-impl.h

cp ./whisper.cpp/include/whisper.h ./cpp/whisper.h
cp ./whisper.cpp/src/whisper.cpp ./cpp/whisper.cpp
Expand All @@ -51,6 +52,7 @@ files=(
"./cpp/ggml-cpp.h"
"./cpp/ggml-metal.h"
"./cpp/ggml-metal.m"
"./cpp/ggml-metal-impl.h"
"./cpp/ggml-quants.h"
"./cpp/ggml-quants.c"
"./cpp/ggml-alloc.h"
Expand Down

0 comments on commit 9c8dc01

Please sign in to comment.