Skip to content

Commit

Permalink
add [[unroll]] and remove unnecessary conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
uniartisan committed Dec 16, 2024
1 parent 64c16c4 commit 6ea605d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

Expand Down Expand Up @@ -29,12 +31,12 @@ void main() {
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;

if (tid >= head_size || batch_id >= B || head_id >= H) {
if (batch_id >= B || head_id >= H) {
return;
}

A_TYPE state[BLOCK_SIZE];
for (uint i = 0; i < head_size; i++) {
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid];
}
Expand All @@ -56,7 +58,7 @@ void main() {
const A_TYPE v_val = v[t];
A_TYPE y = 0.0;

for (uint j = 0; j < head_size; j += 4) {
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
Expand All @@ -78,7 +80,7 @@ void main() {
dst[t] = y;
}

for (uint i = 0; i < head_size; i++) {
[[unroll]] for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ i * head_size + tid] = state[i];
}
Expand Down

0 comments on commit 6ea605d

Please sign in to comment.