Skip to content

Commit

Permalink
llama : rename n_embed to n_embd in rwkv6_time_mix
Browse files Browse the repository at this point in the history
This commit renames n_embed to n_embd in llm_build_rwkv6_time_mix.

The motivation for this change is consistency with the other rwkv6
functions like build_rwkv6 (and other parts of the code base).
  • Loading branch information
danbev committed Sep 16, 2024
1 parent 90a2fff commit 0bfa0df
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9417,7 +9417,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
struct ggml_tensor * cur,
struct ggml_tensor * x_prev,
struct ggml_tensor ** wkv_state) {
size_t n_embed = cur->ne[0];
size_t n_embd = cur->ne[0];
size_t n_seq_tokens = cur->ne[1];
size_t n_seqs = cur->ne[2];

Expand All @@ -9428,8 +9428,8 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(

struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);

sx = ggml_reshape_2d(ctx, sx, n_embed, n_tokens);
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens);
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);

struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_x), cur);

Expand All @@ -9454,11 +9454,11 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
xxx
);

struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], 0);
struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * sizeof(float));
struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 2 * sizeof(float));
struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 3 * sizeof(float));
struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embed, n_tokens, xxx->nb[1], n_embed * n_tokens * 4 * sizeof(float));
struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));

struct ggml_tensor * xw = ggml_add(
ctx,
Expand Down Expand Up @@ -9527,7 +9527,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
)
);

w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embed));
w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);

Expand All @@ -9536,21 +9536,21 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
r = ggml_transpose(ctx, r);

struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

// group norm with head_count groups
cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
cur = ggml_norm(ctx, cur, 64e-5f);

// Convert back to regular vectors.
cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);

cur = ggml_mul(ctx, cur, g);
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);

return ggml_reshape_3d(ctx, cur, n_embed, n_seq_tokens, n_seqs);
return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
}

static struct ggml_tensor * llm_build_rwkv6_channel_mix(
Expand Down

0 comments on commit 0bfa0df

Please sign in to comment.