From 843f0450e5fa16920425d9f6a6b8344c27532114 Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Fri, 2 Feb 2024 16:03:17 -0500
Subject: [PATCH] Implement rotary embedding kernel

---
 kernels/rotary_embedding_kernel.cu | 105 +++++++++++++++++++++
 src/backend/cache.rs               |   8 +-
 src/backend/layers.rs              | 142 +++++++++++++++++++++++++++--
 src/backend/mod.rs                 |  24 +++--
 4 files changed, 261 insertions(+), 18 deletions(-)
 create mode 100644 kernels/rotary_embedding_kernel.cu

diff --git a/kernels/rotary_embedding_kernel.cu b/kernels/rotary_embedding_kernel.cu
new file mode 100644
index 0000000..ea0361e
--- /dev/null
+++ b/kernels/rotary_embedding_kernel.cu
@@ -0,0 +1,105 @@
+#include <stdint.h>
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding: pairs are split across the two
+    // halves of the rotated dimensions.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = VLLM_LDG(cos_ptr + x_index);
+    sin = VLLM_LDG(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding: adjacent pairs of dimensions are rotated.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = VLLM_LDG(cos_ptr + x_index / 2);
+    sin = VLLM_LDG(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+// Templated helper shared by the two extern "C" entry points below; it is a
+// __device__ function because those wrappers call it directly.
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void rotary_embedding_kernel(
+  const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
+  scalar_t* __restrict__ query,                // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                  // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  // Each thread block is responsible for one token.
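+  // Cache layout: row `pos` of `cos_sin_cache` holds `rot_dim` values, the
+  // first `rot_dim / 2` of which are cos values and the remaining half sin
+  // values; `cos_ptr`/`sin_ptr` below point at the two halves.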
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
+  }
+}
+
+// The Rust side loads these entry points by name with only a `_neox`/`_normal`
+// suffix (no dtype suffix), so the element type is fixed to f32 here.
+extern "C" __global__ void rotary_embedding_kernel_neox(
+  const int64_t* __restrict__ positions,     // [batch_size, seq_len] or [num_tokens]
+  float* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  float* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const float* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  rotary_embedding_kernel<float, true>(positions, query, key, cos_sin_cache,
+    rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
+}
+
+extern "C" __global__ void rotary_embedding_kernel_normal(
+  const int64_t* __restrict__ positions,     // [batch_size, seq_len] or [num_tokens]
+  float* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
+  float* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
+  const float* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
+  const int rot_dim,
+  const int64_t query_stride,
+  const int64_t key_stride,
+  const int num_heads,
+  const int num_kv_heads,
+  const int head_size) {
+  rotary_embedding_kernel<float, false>(positions, query, key, cos_sin_cache,
+    rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
+}
diff --git a/src/backend/cache.rs b/src/backend/cache.rs
index 9e9c900..750bf8a 100644
--- a/src/backend/cache.rs
+++ b/src/backend/cache.rs
@@ -7,6 +7,7 @@ use candle_core::{
     },
     DType, Device, IndexOp, Storage, Tensor,
 };
+use either::Either;
 use half::{bf16, f16};
 
 use crate::{
@@ -117,7 +118,7 @@ pub fn reshape_and_cache(
     let kernel = try_api!(get_or_load_func(
         RESHAPE_AND_CACHE_PTX,
         RESHAPE_AND_CACHE_KERNEL,
-        key.dtype(),
+        Either::Left(key.dtype()),
         dev
     ));
@@ -165,6 +166,9 @@ pub fn copy_blocks(
         )));
     }
     let num_layers: u32 = key_caches.len().try_into().unwrap();
+    if num_layers == 0 {
+        return Ok(());
+    }
     let mut key_cache_ptrs = Vec::new();
     key_cache_ptrs.reserve_exact(num_layers as usize);
@@ -238,7 +242,7 @@
     let kernel = try_api!(get_or_load_func(
         COPY_BLOCKS_PTX,
         COPY_BLOCKS_KERNEL,
-        key_caches.first().unwrap().dtype(),
+        Either::Left(key_caches.first().unwrap().dtype()),
         dev,
     ));
diff --git a/src/backend/layers.rs b/src/backend/layers.rs
index 39cb3fd..1b6ac80 100644
--- a/src/backend/layers.rs
+++ b/src/backend/layers.rs
@@ -1,12 +1,140 @@
 use candle_core::Tensor;
+use candle_core::cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig};
+use candle_core::Device;
+use either::Either;
+
+use crate::{
+    backend::{get_or_load_func, ROTARY_EMBEDDING_KERNEL, ROTARY_EMBEDDING_PTX},
+    openai::responses::APIError,
+    try_api,
+};
+
+use super::dispatch_get_cuda_pointer;
 
 pub fn rotary_embedding(
-    _positions: Tensor,
-    _query: &mut Tensor,
-    _key: &mut Tensor,
-    _head_size: usize,
-    _cos_sin_cache: Tensor,
-    _is_neox: bool,
-) {
-    todo!()
+    positions: Tensor,
+    query: &mut Tensor,
+    key: &mut Tensor,
+    head_size: usize,
+    cos_sin_cache: Tensor,
+    is_neox: bool,
+) -> Result<(), APIError> {
+    // Clone the device handle so that `positions` can be consumed below while
+    // the device is still needed to load and launch the kernel.
+    let Device::Cuda(dev) = positions.device().clone() else {
+        return Err(APIError::new(
+            "Expected the positions to be on a CUDA device.".to_string(),
+        ));
+    };
+
+    if !query.device().same_device(positions.device()) {
+        return Err(APIError::new(format!(
+            "`query` and `positions` have different devices, got {:?} and {:?} respectively.",
+            query.device(),
+            positions.device()
+        )));
+    }
+
+    if !key.device().same_device(positions.device()) {
+        return Err(APIError::new(format!(
+            "`key` and `positions` have different devices, got {:?} and {:?} respectively.",
+            key.device(),
+            positions.device()
+        )));
+    }
+
+    if !cos_sin_cache.device().same_device(positions.device()) {
+        return Err(APIError::new(format!(
+            "`cos_sin_cache` and `positions` have different devices, got {:?} and {:?} respectively.",
+            cos_sin_cache.device(),
+            positions.device()
+        )));
+    }
+
+    let num_tokens = query.shape().elem_count() / query.shape().dims().last().unwrap();
+    let rot_dim = *cos_sin_cache.shape().dims().get(1).unwrap();
+    let num_heads = query.shape().dims().last().unwrap() / head_size;
+    let num_kv_heads = key.shape().dims().last().unwrap() / head_size;
+    let query_stride = *query.stride().get(query.stride().len() - 2).unwrap();
+    let key_stride = *key.stride().get(key.stride().len() - 2).unwrap();
+
+    // One thread block per token; at most 512 threads cover the rotated lanes.
+    let launch_conf = LaunchConfig {
+        grid_dim: (num_tokens.try_into().unwrap(), 1u32, 1u32),
+        block_dim: (
+            512u32.min((num_heads * rot_dim / 2).try_into().unwrap()),
+            1u32,
+            1u32,
+        ),
+        shared_mem_bytes: 0,
+    };
+
+    let positions_ptr = dispatch_get_cuda_pointer(positions);
+    let key_ptr = dispatch_get_cuda_pointer(key);
+    let query_ptr = dispatch_get_cuda_pointer(query);
+    let cos_sin_cache_ptr = dispatch_get_cuda_pointer(cos_sin_cache);
+
+    let stream = try_api!(dev.fork_default_stream());
+
+    let kernel = if is_neox {
+        try_api!(get_or_load_func(
+            ROTARY_EMBEDDING_PTX,
+            ROTARY_EMBEDDING_KERNEL,
+            Either::Right("_neox"),
+            &dev
+        ))
+    } else {
+        try_api!(get_or_load_func(
+            ROTARY_EMBEDDING_PTX,
+            ROTARY_EMBEDDING_KERNEL,
+            Either::Right("_normal"),
+            &dev
+        ))
+    };
+
+    // The integer arguments are cast to match the kernel's C signature
+    // (`int` and `int64_t`).
+    try_api!(unsafe {
+        kernel.launch_on_stream(
+            &stream,
+            launch_conf,
+            (
+                positions_ptr,
+                query_ptr,
+                key_ptr,
+                cos_sin_cache_ptr,
+                rot_dim as i32,
+                query_stride as i64,
+                key_stride as i64,
+                num_heads as i32,
+                num_kv_heads as i32,
+                head_size as i32,
+            ),
+        )
+    });
+
+    Ok(())
 }
diff --git a/src/backend/mod.rs b/src/backend/mod.rs
index 51dbd24..81f5f97 100644
--- a/src/backend/mod.rs
+++ b/src/backend/mod.rs
@@ -10,20 +10,25 @@
 const RESHAPE_AND_CACHE_PTX: &str = "kernels/reshape_and_cache_kernel.ptx";
 
 const RESHAPE_AND_CACHE_KERNEL: &str = "reshape_and_cache_kernel";
 
+const ROTARY_EMBEDDING_PTX: &str = "kernels/rotary_embedding_kernel.ptx";
+
+const ROTARY_EMBEDDING_KERNEL: &str = "rotary_embedding_kernel";
+
 pub fn get_or_load_func(
     ptx_file: &'static str,
     kernel_base: &str,
-    dtype: DType,
+    suffix: Either<DType, &'static str>,
     device: &CudaDevice,
 ) -> Result<CudaFunction, APIError> {
-    let suffix = match dtype {
-        DType::U8 => "_u8",
-        DType::U32 => "_u32",
-        DType::I64 => "_i64",
-        DType::BF16 => "_bf16",
-        DType::F16 => "_f16",
-        DType::F32 => "_f32",
-        DType::F64 => "_f64",
+    let suffix = match suffix {
+        Either::Left(DType::U8) => "_u8",
+        Either::Left(DType::U32) => "_u32",
+        Either::Left(DType::I64) => "_i64",
+        Either::Left(DType::BF16) => "_bf16",
+        Either::Left(DType::F16) => "_f16",
+        Either::Left(DType::F32) => "_f32",
+        Either::Left(DType::F64) => "_f64",
+        Either::Right(data) => data,
     };
     let kernel = kernel_base.to_owned() + suffix;
     device
@@ -82,6 +87,7 @@ use candle_core::{
     cuda_backend::cudarc::driver::{CudaFunction, DeviceRepr},
     CudaDevice, DType,
 };
+use either::Either;
 pub use layers::*;
 pub use paged_attention::*;
 pub use std::ops::Deref;
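
Usage note (not part of the patch): a minimal call-site sketch for the new
`rotary_embedding` entry point. The shapes, the zero-filled tensors, and the
`candle_vllm::backend` import path are illustrative assumptions; a real
`cos_sin_cache` would be precomputed from the RoPE frequencies rather than
zeroed.

    use candle_core::{DType, Device, Tensor};
    use candle_vllm::backend::rotary_embedding; // assumed crate/module path

    fn rope_example() -> Result<(), candle_core::Error> {
        let device = Device::new_cuda(0)?;
        let (num_tokens, num_heads, num_kv_heads, head_size) = (4usize, 8usize, 8usize, 64usize);
        let max_position = 2048usize;

        // One position per token, i64 as the kernel expects.
        let positions = Tensor::zeros(num_tokens, DType::I64, &device)?;
        // Query/key projections flattened to [num_tokens, num_heads * head_size].
        let mut query = Tensor::zeros((num_tokens, num_heads * head_size), DType::F32, &device)?;
        let mut key = Tensor::zeros((num_tokens, num_kv_heads * head_size), DType::F32, &device)?;
        // [max_position, rot_dim] cache: cos half followed by sin half per row.
        let cos_sin_cache = Tensor::zeros((max_position, head_size), DType::F32, &device)?;

        // GPT-NeoX style rotation over the full head dimension (rot_dim == head_size).
        rotary_embedding(positions, &mut query, &mut key, head_size, cos_sin_cache, true)
            .expect("rotary embedding kernel launch failed");
        Ok(())
    }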