Implement rotary embedding kernel
EricLBuehler committed Feb 2, 2024
1 parent 6c343d4 commit 843f045
Showing 4 changed files with 261 additions and 18 deletions.
105 changes: 105 additions & 0 deletions kernels/rotary_embedding_kernel.cu
@@ -0,0 +1,105 @@
#include <stdint.h>

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif
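
// The extern "C" entry points below are monomorphic, so `scalar_t` must name a
// concrete type at file scope. This alias is an assumption for illustration;
// the build would normally select the element type when generating per-dtype
// PTX. (Inside the templates, the template parameter shadows this alias.)
using scalar_t = float;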

template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = VLLM_LDG(cos_ptr + x_index);
sin = VLLM_LDG(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = VLLM_LDG(cos_ptr + x_index / 2);
sin = VLLM_LDG(sin_ptr + x_index / 2);
}

const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}

// Device-side helper: it is called directly from the extern "C" entry points
// below, so it must be __device__ rather than a launchable __global__ kernel.
template<typename scalar_t, bool IS_NEOX>
inline __device__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
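// One cache row per position, laid out as [cos(0..embed_dim) | sin(0..embed_dim)].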

const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}

const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
}

extern "C" __global__ void rotary_embedding_kernel_neox(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
rotary_embedding_kernel<scalar_t, true>(positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
}

extern "C" __global__ void rotary_embedding_kernel_normal(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
rotary_embedding_kernel<scalar_t, false>(positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
}
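
For intuition, here is a small host-side reference of the rotation the kernel applies to each (x, y) pair. This is an illustrative Rust sketch, not code from the commit; `rotate_head` and its slice layout are assumptions for the example.

/// Host-side reference of the per-pair rotation in `apply_rotary_embedding`.
/// `head` holds one head's values; `cos` and `sin` are the two halves of one
/// cache row, each of length embed_dim = rot_dim / 2.
fn rotate_head(head: &mut [f32], cos: &[f32], sin: &[f32], is_neox: bool) {
    let embed_dim = cos.len();
    for rot_offset in 0..embed_dim {
        // Pair layout: NeoX splits (x, y) across the two halves of the rotated
        // dims; GPT-J interleaves them as (even, odd) neighbors.
        let (x_index, y_index) = if is_neox {
            (rot_offset, embed_dim + rot_offset)
        } else {
            (2 * rot_offset, 2 * rot_offset + 1)
        };
        let (x, y) = (head[x_index], head[y_index]);
        head[x_index] = x * cos[rot_offset] - y * sin[rot_offset];
        head[y_index] = y * cos[rot_offset] + x * sin[rot_offset];
    }
}

Note that the Rust side loads this kernel from PTX at runtime (ROTARY_EMBDEDDING_PTX below points at kernels/rotary_embedding_kernel.ptx), so the .cu file must be compiled ahead of time, for example with nvcc --ptx.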
8 changes: 6 additions & 2 deletions src/backend/cache.rs
@@ -7,6 +7,7 @@ use candle_core::{
},
DType, Device, IndexOp, Storage, Tensor,
};
use either::Either;
use half::{bf16, f16};

use crate::{
@@ -117,7 +118,7 @@ pub fn reshape_and_cache(
let kernel = try_api!(get_or_load_func(
RESHAPE_AND_CACHE_PTX,
RESHAPE_AND_CACHE_KERNEL,
key.dtype(),
Either::Left(key.dtype()),
dev
));

@@ -165,6 +166,9 @@ pub fn copy_blocks(
)));
}
let num_layers: u32 = key_caches.len().try_into().unwrap();
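// Early-return on zero layers: there is nothing to copy, and it avoids
// launching the kernel below with empty pointer tables.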
if num_layers == 0 {
return Ok(());
}

let mut key_cache_ptrs = Vec::new();
key_cache_ptrs.reserve_exact(num_layers as usize);
@@ -238,7 +242,7 @@ pub fn copy_blocks(
let kernel = try_api!(get_or_load_func(
COPY_BLOCKS_PTX,
COPY_BLOCKS_KERNEL,
key_caches.first().unwrap().dtype(),
Either::Left(key_caches.first().unwrap().dtype()),
dev,
));

142 changes: 135 additions & 7 deletions src/backend/layers.rs
@@ -1,12 +1,140 @@
use candle_core::{
cuda_backend::cudarc::driver::{LaunchAsync, LaunchConfig},
Device, Tensor,
};
use either::Either;

use crate::{
backend::{get_or_load_func, ROTARY_EMBDEDDING_KERNEL, ROTARY_EMBDEDDING_PTX},
try_api, APIError,
};

use super::dispatch_get_cuda_pointer;

pub fn rotary_embedding(
_positions: Tensor,
_query: &mut Tensor,
_key: &mut Tensor,
_head_size: usize,
_cos_sin_cache: Tensor,
_is_neox: bool,
positions: Tensor,
query: &mut Tensor,
key: &mut Tensor,
head_size: usize,
cos_sin_cache: Tensor,
is_neox: bool,
) -> Result<(), APIError> {
todo!()
let positions_dev = positions.device();
let Device::Cuda(dev) = positions_dev else {
return Err(APIError::new(format!(
"Expected `positions` to be on a CUDA device, got {:?}.",
positions_dev
)));
};

if !query.device().same_device(positions.device()) {
return Err(APIError::new(format!(
"`query` and `positions` have different devices, got {:?} and {:?} respectively.",
query.device(),
positions.device()
)));
}

if !key.device().same_device(positions.device()) {
return Err(APIError::new(format!(
"`key` and `positions` have different devices, got {:?} and {:?} respectively.",
key.device(),
positions.device()
)));
}

if !cos_sin_cache.device().same_device(positions.device()) {
return Err(APIError::new(format!(
"`cos_sin_cache` and `positions` have different devices, got {:?} and {:?} respectively.",
cos_sin_cache.device(),
positions.device()
)));
}

let num_tokens = query.shape().elem_count() / query.shape().dims().last().unwrap();
let rot_dim = *cos_sin_cache.shape().dims().get(1).unwrap();
let num_heads = query.shape().dims().last().unwrap() / head_size;
let num_kv_heads = key.shape().dims().last().unwrap() / head_size;
let query_stride = *query.stride().get(query.stride().len() - 2).unwrap();
let key_stride = *key.stride().get(key.stride().len() - 2).unwrap();

let launch_conf = LaunchConfig {
grid_dim: (num_tokens.try_into().unwrap(), 1u32, 1u32),
block_dim: (
512.min((num_heads * rot_dim / 2).try_into().unwrap()),
1u32,
1u32,
),
shared_mem_bytes: 0,
};
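// Illustrative numbers: with num_heads = 32 and rot_dim = 128, each token has
// 32 * 64 = 2048 (x, y) pairs, so blocks are capped at 512 threads and each
// thread rotates 4 pairs via the kernel's stride loop.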

let positions_ptr = dispatch_get_cuda_pointer(positions);
let key_ptr = dispatch_get_cuda_pointer(key);
let query_ptr = dispatch_get_cuda_pointer(query);
let cos_sin_cache_ptr = dispatch_get_cuda_pointer(cos_sin_cache);

let stream = try_api!(dev.fork_default_stream());

let kernel = if is_neox {
try_api!(get_or_load_func(
ROTARY_EMBDEDDING_PTX,
ROTARY_EMBDEDDING_KERNEL,
Either::Right("_neox"),
dev
))
} else {
try_api!(get_or_load_func(
ROTARY_EMBDEDDING_PTX,
ROTARY_EMBDEDDING_KERNEL,
Either::Right("_normal"),
dev
))
};

try_api!(unsafe {
kernel.launch_on_stream(
&stream,
launch_conf,
(
positions_ptr,
query_ptr,
key_ptr,
cos_sin_cache_ptr,
// Cast to the exact widths the kernel signature declares (int / int64_t).
rot_dim as i32,
query_stride as i64,
key_stride as i64,
num_heads as i32,
num_kv_heads as i32,
head_size as i32,
),
)
});
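
// For reference, the vLLM C++ dispatch that this function ports: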

/*
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});*/
Ok(())
}
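
As a usage sketch (the shapes, `device`, and the enclosing Result-returning function are hypothetical; only `rotary_embedding` itself comes from this commit):

// Hypothetical call: 16 tokens, 32 query heads, 8 KV heads, head_size = 128,
// rotating the full head dimension. `device` is an existing CUDA Device.
let positions = Tensor::zeros(16, DType::I64, &device)?;
let mut query = Tensor::zeros((16, 32 * 128), DType::F32, &device)?;
let mut key = Tensor::zeros((16, 8 * 128), DType::F32, &device)?;
let cos_sin_cache = Tensor::zeros((4096, 128), DType::F32, &device)?; // [max_position, rot_dim]
rotary_embedding(positions, &mut query, &mut key, 128, cos_sin_cache, true)?;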
24 changes: 15 additions & 9 deletions src/backend/mod.rs
@@ -10,20 +10,25 @@ const RESHAPE_AND_CACHE_PTX: &str = "kernels/reshape_and_cache_kernel.ptx";

const RESHAPE_AND_CACHE_KERNEL: &str = "reshape_and_cache_kernel";

const ROTARY_EMBDEDDING_PTX: &str = "kernels/rotary_embedding_kernel.ptx";

const ROTARY_EMBDEDDING_KERNEL: &str = "rotary_embedding_kernel";

pub fn get_or_load_func(
ptx_file: &'static str,
kernel_base: &str,
dtype: DType,
suffix: Either<DType, &str>,
device: &CudaDevice,
) -> Result<CudaFunction, APIError> {
let suffix = match dtype {
DType::U8 => "_u8",
DType::U32 => "_u32",
DType::I64 => "_i64",
DType::BF16 => "_bf16",
DType::F16 => "_f16",
DType::F32 => "_f32",
DType::F64 => "_f64",
let suffix = match suffix {
Either::Left(DType::U8) => "_u8",
Either::Left(DType::U32) => "_u32",
Either::Left(DType::I64) => "_i64",
Either::Left(DType::BF16) => "_bf16",
Either::Left(DType::F16) => "_f16",
Either::Left(DType::F32) => "_f32",
Either::Left(DType::F64) => "_f64",
Either::Right(data) => data,
};
let kernel = kernel_base.to_owned() + suffix;
device
@@ -82,6 +87,7 @@ use candle_core::{
cuda_backend::cudarc::driver::{CudaFunction, DeviceRepr},
CudaDevice, DType,
};
use either::Either;
pub use layers::*;
pub use paged_attention::*;
pub use std::ops::Deref;
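
With the Either-based suffix, one loader now serves both naming schemes. A hedged sketch of the two call styles (error handling and `dev` setup elided; the kernel-name constants are the ones defined above):

// Dtype suffix (existing behavior): resolves to "reshape_and_cache_kernel_f16".
let reshape = get_or_load_func(RESHAPE_AND_CACHE_PTX, RESHAPE_AND_CACHE_KERNEL, Either::Left(DType::F16), dev)?;
// String suffix (new): resolves to "rotary_embedding_kernel_neox".
let rope = get_or_load_func(ROTARY_EMBDEDDING_PTX, ROTARY_EMBDEDDING_KERNEL, Either::Right("_neox"), dev)?;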
