-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(ios): add missing ggml-metal-impl.h
- Loading branch information
Showing
3 changed files
with
252 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
#ifndef LM_GGML_METAL_IMPL | ||
#define LM_GGML_METAL_IMPL | ||
|
||
// kernel argument structs | ||
// | ||
// - element counters (e.g. ne00) typically use int32_t to reduce register usage | ||
// however, be careful from int overflows when using those in the kernel implementation | ||
// | ||
// - strides (e.g. nb00) use uint64_t | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
int32_t ne03; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne10; | ||
int32_t ne11; | ||
int32_t ne12; | ||
int32_t ne13; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
uint64_t nb13; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int32_t ne2; | ||
int32_t ne3; | ||
uint64_t nb0; | ||
uint64_t nb1; | ||
uint64_t nb2; | ||
uint64_t nb3; | ||
int32_t dim; | ||
} lm_ggml_metal_kargs_concat; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
int32_t ne03; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne10; | ||
int32_t ne11; | ||
int32_t ne12; | ||
int32_t ne13; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
uint64_t nb13; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int32_t ne2; | ||
int32_t ne3; | ||
uint64_t nb0; | ||
uint64_t nb1; | ||
uint64_t nb2; | ||
uint64_t nb3; | ||
uint64_t offs; | ||
} lm_ggml_metal_kargs_bin; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
int32_t ne03; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int32_t ne2; | ||
int32_t ne3; | ||
uint64_t nb0; | ||
uint64_t nb1; | ||
uint64_t nb2; | ||
uint64_t nb3; | ||
} lm_ggml_metal_kargs_repeat; | ||
|
||
typedef struct { | ||
int64_t ne00; | ||
int64_t ne01; | ||
int64_t ne02; | ||
int64_t ne03; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int64_t ne0; | ||
int64_t ne1; | ||
int64_t ne2; | ||
int64_t ne3; | ||
uint64_t nb0; | ||
uint64_t nb1; | ||
uint64_t nb2; | ||
uint64_t nb3; | ||
} lm_ggml_metal_kargs_cpy; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
int32_t ne03; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int32_t ne2; | ||
int32_t ne3; | ||
uint64_t nb0; | ||
uint64_t nb1; | ||
uint64_t nb2; | ||
uint64_t nb3; | ||
int32_t n_past; | ||
int32_t n_dims; | ||
int32_t n_ctx_orig; | ||
float freq_base; | ||
float freq_scale; | ||
float ext_factor; | ||
float attn_factor; | ||
float beta_fast; | ||
float beta_slow; | ||
} lm_ggml_metal_kargs_rope; | ||
|
||
typedef struct { | ||
int32_t ne01; | ||
int32_t ne02; | ||
int32_t ne03; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne11; | ||
int32_t ne_12_2; // assume K and V are same shape | ||
int32_t ne_12_3; | ||
uint64_t nb_12_1; | ||
uint64_t nb_12_2; | ||
uint64_t nb_12_3; | ||
uint64_t nb31; | ||
int32_t ne1; | ||
int32_t ne2; | ||
float scale; | ||
float max_bias; | ||
float m0; | ||
float m1; | ||
uint16_t n_head_log2; | ||
float logit_softcap; | ||
} lm_ggml_metal_kargs_flash_attn_ext; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne02; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne12; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
uint64_t nb13; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int16_t r2; | ||
int16_t r3; | ||
} lm_ggml_metal_kargs_mul_mm; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
uint64_t nb03; | ||
int32_t ne10; | ||
int32_t ne11; | ||
int32_t ne12; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
uint64_t nb13; | ||
int32_t ne0; | ||
int32_t ne1; | ||
int16_t r2; | ||
int16_t r3; | ||
} lm_ggml_metal_kargs_mul_mv; | ||
|
||
typedef struct { | ||
int32_t nei0; | ||
int32_t nei1; | ||
uint64_t nbi1; | ||
int32_t ne00; | ||
int32_t ne02; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
int32_t ne11; | ||
int32_t ne12; | ||
int32_t ne13; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
int32_t ne0; | ||
int32_t ne1; | ||
} lm_ggml_metal_kargs_mul_mm_id; | ||
|
||
typedef struct { | ||
int32_t nei0; | ||
int32_t nei1; | ||
uint64_t nbi1; | ||
int32_t ne00; | ||
int32_t ne01; | ||
int32_t ne02; | ||
uint64_t nb00; | ||
uint64_t nb01; | ||
uint64_t nb02; | ||
int32_t ne10; | ||
int32_t ne11; | ||
int32_t ne12; | ||
int32_t ne13; | ||
uint64_t nb10; | ||
uint64_t nb11; | ||
uint64_t nb12; | ||
int32_t ne0; | ||
int32_t ne1; | ||
uint64_t nb1; | ||
} lm_ggml_metal_kargs_mul_mv_id; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne00_4; | ||
uint64_t nb01; | ||
float eps; | ||
} lm_ggml_metal_kargs_norm; | ||
|
||
typedef struct { | ||
int32_t ne00; | ||
int32_t ne00_4; | ||
uint64_t nb01; | ||
float eps; | ||
} lm_ggml_metal_kargs_rms_norm; | ||
|
||
#endif // LM_GGML_METAL_IMPL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732158453708-0.8549147503631913/node | ||
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1732160897686-0.7191373753799657/node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters