Fix correctness regression in Llama-3.2-90B-Vision-Instruct-FP8-KV test
wunhuang committed Nov 27, 2024
1 parent 2302ad6 commit 15029b8
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions csrc/layernorm_kernels.cu
@@ -234,19 +234,22 @@ void rms_norm(torch::Tensor& out,  // [..., hidden_size]
   int num_tokens = input.numel() / hidden_size;
   int vec_size = 16 / input.element_size();
   int vec_hidden_size = hidden_size / vec_size;
+  bool can_run_vectorize = (hidden_size % vec_size) == 0 ? true : false;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(vec_hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
 #ifdef __HIP__MI300_MI250__
-  if (vec_size % 8 == 0) {
+  if (vec_size % 8 == 0 && can_run_vectorize) {
+    dim3 block(std::min(vec_hidden_size, 1024));
     LAUNCH_RMS_NORM(8);
   } else {
+    dim3 block(std::min(hidden_size, 1024));
     LAUNCH_RMS_NORM(0);
   }
 #else
+  dim3 block(std::min(hidden_size, 1024));
   LAUNCH_RMS_NORM(0);
 #endif
 }
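
For context on the guard added above: the vectorized RMSNorm path loads 16 bytes per thread, i.e. 16 / element_size elements, so it is only safe when hidden_size splits into whole vectors. Below is a minimal standalone sketch of that eligibility check in plain C++; the function name and the element/hidden sizes are illustrative assumptions, not values taken from the kernel or the Llama-3.2-90B test.

#include <cstdio>

// Mirrors the check added in this commit: one thread handles a 16-byte
// vector of 16 / element_size elements, so the vectorized kernel may only
// run when hidden_size is an exact multiple of that vector width.
static bool can_run_vectorized(int hidden_size, int element_size) {
  int vec_size = 16 / element_size;      // elements per 16-byte load
  return (hidden_size % vec_size) == 0;  // must split into whole vectors
}

int main() {
  // fp16/bf16 elements are 2 bytes wide, so vec_size = 8.
  std::printf("4096: %d\n", can_run_vectorized(4096, 2));  // 1: 4096 % 8 == 0
  std::printf("4100: %d\n", can_run_vectorized(4100, 2));  // 0: 4100 % 8 == 4
  return 0;
}

As the diff also shows, each launch path now declares its own block size, so the scalar LAUNCH_RMS_NORM(0) fallback is launched with up to hidden_size threads rather than the smaller vec_hidden_size.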
