[BUG][CUDA] Revert CUDA LayerNorm Grid.x Grid.y to fix >65535 bug on large layernorm
doxutx committed Mar 26, 2024
1 parent 1e2b6a6 commit f453576
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions source/tnn/device/cuda/acc/cuda_layer_norm_layer_acc.cu
@@ -63,12 +63,12 @@ template<typename T>
 __global__ void ln_mul_add_kernel(const T *input, T *output, const T *scale, const T *bias,
                                   const LNFloat2 *mean_var,
                                   const int count, const float eps) {
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int total_offset = blockIdx.y * count + offset;
+    int offset = blockIdx.y * blockDim.x + threadIdx.x;
+    int total_offset = blockIdx.x * count + offset;
     if (offset < count) {
         const float* mean_var_float = reinterpret_cast<const float*>(mean_var);
-        float mean = mean_var_float[blockIdx.y * 2 + 0] / float(count);
-        float var = mean_var_float[blockIdx.y * 2 + 1] / float(count) - mean * mean;
+        float mean = mean_var_float[blockIdx.x * 2 + 0] / float(count);
+        float var = mean_var_float[blockIdx.x * 2 + 1] / float(count) - mean * mean;
         var = 1.0 / sqrt(var + eps);
         float k = float(scale[offset]) * var;
         float b = - mean * k + float(bias[offset]);
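
For context beyond the diff: after the swap, blockIdx.x selects which of the `channels` rows is being normalized, while blockIdx.y and threadIdx.x together walk the `count` elements inside that row. A minimal compilable sketch of that indexing, with hypothetical names (not TNN code) and assuming a 1-D block of 1024 threads:

```cuda
#include <cuda_runtime.h>

// Illustrative stand-in for ln_mul_add_kernel's indexing after the fix.
__global__ void row_index_sketch(float *output, int count) {
    // blockIdx.y tiles one row in chunks of blockDim.x threads.
    int offset = blockIdx.y * blockDim.x + threadIdx.x;
    // blockIdx.x picks the row; grid.x tolerates counts far above 65535.
    int total_offset = blockIdx.x * count + offset;
    if (offset < count) {
        output[total_offset] = 0.0f;  // placeholder for the scale/bias math
    }
}
```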
@@ -126,8 +126,8 @@ Status CudaLayerNormLayerAcc::Forward(const std::vector<Blob *> &inputs, const s
     const int THREAD_PER_BLOCK = 1024;
     int num_blocks = (channel_area - 1) / THREAD_PER_BLOCK + 1;
     dim3 griddim;
-    griddim.x = num_blocks;
-    griddim.y = channels; // batch_size
+    griddim.x = channels; // batch_size
+    griddim.y = num_blocks;

     // Re-Allocate Temp Buffer if size of existing one is not enough.
     ResizeTempBuf(0, sizeof(LNFloat2) * channels); // Buffer for stored mean & var
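Why the swap matters: CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x can go up to 2^31 - 1 on compute capability 3.0 and newer, so the potentially huge row count (`channels`) has to live in grid.x. A standalone sketch of the fixed launch configuration, with illustrative sizes (not from the repository):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void touch_rows(float *out, int count) {
    int offset = blockIdx.y * blockDim.x + threadIdx.x;
    if (offset < count) out[blockIdx.x * (size_t)count + offset] = 1.0f;
}

int main() {
    const int THREAD_PER_BLOCK = 1024;
    int channels = 70000;     // rows to normalize; > 65535 broke the old grid.y layout
    int channel_area = 2048;  // elements per row
    int num_blocks = (channel_area - 1) / THREAD_PER_BLOCK + 1;

    float *out;
    cudaMalloc(&out, sizeof(float) * (size_t)channels * channel_area);

    dim3 griddim;
    griddim.x = channels;    // large dimension: limit is 2^31 - 1
    griddim.y = num_blocks;  // small dimension: fits under the 65535 cap
    touch_rows<<<griddim, THREAD_PER_BLOCK>>>(out, channel_area);

    // With the pre-fix layout (griddim.y = channels) this launch would fail
    // with cudaErrorInvalidConfiguration before any kernel code ran.
    printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}
```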
