diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index de098a9ee0c19..974e69954e8d0 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -234,19 +234,22 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] int num_tokens = input.numel() / hidden_size; int vec_size = 16 / input.element_size(); int vec_hidden_size = hidden_size / vec_size; + bool can_run_vectorize = (hidden_size%vec_size) == 0:true?false; dim3 grid(num_tokens); - dim3 block(std::min(vec_hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); #ifdef __HIP__MI300_MI250__ - if (vec_size % 8 == 0) { + if (vec_size % 8 == 0 && can_run_vectorize) { + dim3 block(std::min(vec_hidden_size, 1024)); LAUNCH_RMS_NORM(8); } else { + dim3 block(std::min(hidden_size, 1024)); LAUNCH_RMS_NORM(0); } #else + dim3 block(std::min(hidden_size, 1024)); LAUNCH_RMS_NORM(0); #endif }