Commit c1a3a9c
llama: rwkv6: Use ggml_norm instead of ggml_group_norm
Co-authored-by: compilade <[email protected]>
MollySophia and compilade authored Aug 25, 2024
1 parent e29b446 commit c1a3a9c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/llama.cpp
@@ -9496,10 +9496,10 @@ static struct ggml_tensor * llm_build_time_mix_rwkv6(
     cur = ggml_view_1d(ctx, wkv_output, n_embed * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embed * head_size * n_seqs, n_embed * n_tokens * sizeof(float));

-    // ggml_group_norm considers groups in the third dimension.
-    cur = ggml_reshape_4d(ctx, cur, n_embed / head_count, 1, head_count, n_tokens);
-    cur = ggml_group_norm(ctx, cur, head_count, 64e-5f);
-    // Convert back to a regular vector.
+    // group norm with head_count groups
+    cur = ggml_reshape_3d(ctx, cur, n_embed / head_count, head_count, n_tokens);
+    cur = ggml_norm(ctx, cur, 64e-5f);
+    // Convert back to regular vectors.
     cur = ggml_reshape_2d(ctx, cur, n_embed, n_tokens);
     cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
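The change works because a group norm over `head_count` groups is the same computation as a plain layer norm (`ggml_norm`) applied along the first dimension once the tensor is reshaped so that each head occupies its own row. A minimal NumPy sketch of that equivalence, with illustrative helper names (`group_norm`, `norm_per_head`) that are not part of ggml or llama.cpp:

```python
import numpy as np

def group_norm(x, n_groups, eps):
    # x: (n_tokens, n_embed). Split channels into n_groups groups and
    # normalize each group to zero mean / unit variance (no affine),
    # mirroring what ggml_group_norm computes.
    n_tokens, n_embed = x.shape
    g = x.reshape(n_tokens, n_groups, n_embed // n_groups)
    mean = g.mean(axis=-1, keepdims=True)
    var = g.var(axis=-1, keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n_tokens, n_embed)

def norm_per_head(x, head_count, eps):
    # Same result via the commit's approach: reshape so each head is a
    # contiguous row, then apply a plain layer norm (ggml_norm) per row.
    n_tokens, n_embed = x.shape
    h = x.reshape(n_tokens * head_count, n_embed // head_count)
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return ((h - mean) / np.sqrt(var + eps)).reshape(n_tokens, n_embed)

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 8)).astype(np.float32)
a = group_norm(x, n_groups=4, eps=64e-5)       # old path (ggml_group_norm)
b = norm_per_head(x, head_count=4, eps=64e-5)  # new path (ggml_norm)
assert np.allclose(a, b, atol=1e-5)
```

The subsequent `ggml_mul`/`ggml_add` with `time_mix_ln` and `time_mix_ln_b` then applies the affine scale and bias over the full `n_embed` vector, as in the diff.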
