Skip to content

Commit

Permalink
modify unnecessary calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
A3shTnT committed Dec 4, 2024
1 parent 828e4f7 commit e52a22d
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions ggml/src/ggml-cuda/ssm_scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,25 @@ __global__ void __launch_bounds__(splitD, 2)
__syncthreads();

for (int i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + wid * warpSize + wtid];
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + wid * warpSize + wtid] * dt_soft_plus;
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (int j = 0; j < N; j++) {
float state = (smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] *
expf(dt_soft_plus * smem_A[(wid * warpSize + wtid) * stride_sA + j])) +
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[(wid * warpSize + wtid) * stride_s + j] = state;
s_block[tid * stride_s + j] = state;
} else {
smem_s0[(wid * warpSize + wtid) * stride_ss0 + j] = state;
smem_s0[tid * stride_ss0 + j] = state;
}
}
__syncthreads();
y_block[i * stride_y + wid * warpSize + wtid] = sumf;
y_block[i * stride_y + tid] = sumf;
}
}

Expand Down

0 comments on commit e52a22d

Please sign in to comment.