minor : style + indentation
ggerganov committed Aug 30, 2024
1 parent 7004323 commit 59dc2e7
Showing 2 changed files with 30 additions and 25 deletions.
12 changes: 6 additions & 6 deletions ggml/include/ggml.h
@@ -1899,12 +1899,12 @@ extern "C" {

     GGML_API struct ggml_tensor * ggml_rwkv_wkv(
             struct ggml_context * ctx,
-            struct ggml_tensor * k,
-            struct ggml_tensor * v,
-            struct ggml_tensor * r,
-            struct ggml_tensor * tf,
-            struct ggml_tensor * td,
-            struct ggml_tensor * state);
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * r,
+            struct ggml_tensor  * tf,
+            struct ggml_tensor  * td,
+            struct ggml_tensor  * state);
 
     // custom operators
 
43 changes: 24 additions & 19 deletions src/llama.cpp
Expand Up @@ -9383,24 +9383,25 @@ static struct ggml_tensor * llm_build_mamba(
     return cur;
 }
 
-static struct ggml_tensor * llm_build_time_mix_rwkv6(
-    struct llama_context & lctx,
-    struct ggml_context * ctx,
-    const struct llama_layer * layer,
-    struct ggml_tensor * cur,
-    struct ggml_tensor * x_prev,
-    struct ggml_tensor ** wkv_state) {
-    size_t n_embed = cur->ne[0];
+static struct ggml_tensor * llm_build_rwkv6_time_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev,
+        struct ggml_tensor ** wkv_state) {
+    size_t n_embed      = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
-    size_t n_seqs = cur->ne[2];
-    size_t head_size = layer->time_mix_first->ne[0];
+    size_t n_seqs       = cur->ne[2];
+
+    size_t head_size  = layer->time_mix_first->ne[0];
     size_t head_count = layer->time_mix_first->ne[1];
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
-    sx = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
+    sx  = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);
@@ -9498,20 +9499,23 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
             ggml_mul_mat(ctx, layer->time_mix_decay_w1, xw)
         )
     );
+
     w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
     w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
 
     k = ggml_transpose(ctx, k);
     v = ggml_transpose(ctx, v);
     r = ggml_transpose(ctx, r);
 
     struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
     cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
+
     // group norm with head_count groups
     cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
     cur = ggml_norm(ctx, cur, 64e-5f);
+
     // Convert back to regular vectors.
     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
@@ -9522,12 +9526,12 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
 }
 
-static struct ggml_tensor * llm_build_channel_mix_rwkv6(
-    struct llama_context & lctx,
-    struct ggml_context * ctx,
-    const struct llama_layer * layer,
-    struct ggml_tensor * cur,
-    struct ggml_tensor * x_prev) {
+static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+        struct llama_context & lctx,
+        struct ggml_context * ctx,
+        const struct llama_layer * layer,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * x_prev) {
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
     struct ggml_tensor * xk = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_k), cur);
     struct ggml_tensor * xr = ggml_add(ctx, ggml_mul(ctx, sx, layer->channel_mix_lerp_r), cur);
@@ -9540,6 +9544,7 @@ static struct ggml_tensor * llm_build_channel_mix_rwkv6(
             llm_build_lora_mm(lctx, ctx, layer->channel_mix_key, xk)
         )
     );
+
     return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
 }

@@ -15111,7 +15116,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_time_mix_rwkv6(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -15134,7 +15139,7 @@ struct llm_build_context {
                 ggml_view_3d(ctx0, x_norm_ffn, n_embd, n_seq_tokens - 1, n_seqs, x_norm_ffn->nb[1], x_norm_ffn->nb[2], 0),
                 1
             );
-            cur = ggml_add(ctx0, cur, llm_build_channel_mix_rwkv6(lctx, ctx0, layer, x_norm_ffn, x_prev));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, x_norm_ffn, x_prev));
             ggml_build_forward_expand(gf, cur);
 
             struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));