Commit

applying the missing formatting
gshtras committed Nov 27, 2024
1 parent feae871 commit e867af8
Showing 2 changed files with 16 additions and 15 deletions.
29 changes: 15 additions & 14 deletions csrc/activation_kernels.cu
@@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {

 // Launch activation and gating kernel.
 #ifdef USE_ROCM
-  #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \
-  int d = input.size(-1) / 2; \
-  int64_t num_tokens = input.numel() / input.size(-1); \
-  dim3 grid(num_tokens); \
-  dim3 block(std::min(d, 1024)); \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
-  VLLM_DISPATCH_FLOATING_TYPES( \
-      input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \
-        vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
-            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fnuz>(), \
-                                         input.data_ptr<scalar_t>(), d, \
-                                         1.0 / (*scale.data_ptr<float>())); \
-        });
+  #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL)                   \
+    int d = input.size(-1) / 2;                                          \
+    int64_t num_tokens = input.numel() / input.size(-1);                 \
+    dim3 grid(num_tokens);                                               \
+    dim3 block(std::min(d, 1024));                                       \
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));    \
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();        \
+    VLLM_DISPATCH_FLOATING_TYPES(                                        \
+        input.scalar_type(), "scaled_act_and_mul_kernel", [&] {          \
+          vllm::scaled_act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>    \
+              <<<grid, block, 0, stream>>>(                              \
+                  out.data_ptr<c10::Float8_e4m3fnuz>(),                  \
+                  input.data_ptr<scalar_t>(), d,                         \
+                  1.0 / (*scale.data_ptr<float>()));                     \
+        });
 #endif

 void silu_and_mul(torch::Tensor& out, // [..., d]
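For context, here is a minimal sketch of how this macro is meant to be consumed. The caller name scaled_silu_and_mul, the vllm::silu_kernel functor, and the exact signature are assumptions modeled on the unscaled silu_and_mul pattern, not taken from this diff:

// Sketch only: LAUNCH_SCALED_ACTIVATION_GATE_KERNEL reads `out`, `input`,
// and `scale` from the enclosing scope, so the caller's parameter names
// must match the identifiers the macro body uses.
void scaled_silu_and_mul(torch::Tensor& out,    // [..., d], fp8 e4m3fnuz output
                         torch::Tensor& input,  // [..., 2 * d]
                         torch::Tensor& scale)  // single-element fp32 tensor
{
  // Expands to the grid/block setup and the scaled_act_and_mul_kernel
  // launch shown in the hunk above.
  LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}

Note that the host computes 1.0 / (*scale.data_ptr<float>()) once and passes the reciprocal to the kernel, so quantization on the device is a per-element multiply rather than a divide.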
2 changes: 1 addition & 1 deletion csrc/layernorm_kernels.cu
@@ -247,7 +247,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
     LAUNCH_RMS_NORM(0);
   }
 #else
-    LAUNCH_RMS_NORM(0);
+  LAUNCH_RMS_NORM(0);
 #endif
 }

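The hunk above is a whitespace-only fix: the LAUNCH_RMS_NORM(0) call in the #else branch is re-indented to match the surrounding block. For orientation, a rough sketch of the dispatch it sits in; the if/else shape is taken from the visible context lines, while the guard condition and the width-8 variant are assumptions modeled on the analogous aligned-pointer dispatch used elsewhere in vLLM:

// Sketch of the surrounding structure, not the verbatim source.
#ifdef USE_ROCM
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_RMS_NORM(8);  // wider, vectorized variant when alignment allows
  } else {
    LAUNCH_RMS_NORM(0);  // generic fallback
  }
#else
  LAUNCH_RMS_NORM(0);    // non-ROCm builds always take the generic path
#endif
}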
