Commit 843f045 (1 parent: 6c343d4)
Showing 4 changed files with 261 additions and 18 deletions.
@@ -0,0 +1,105 @@
#include <stdint.h>

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

// Shared device-side body for the extern "C" entry points below. It is a
// __device__ helper rather than __global__ because the wrappers call it
// directly, which CUDA does not allow for a kernel-to-kernel call.
template<typename scalar_t, bool IS_NEOX>
inline __device__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,      // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,               // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                 // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}

// The wrappers below use `scalar_t` directly, so a concrete element type has
// to be supplied when this file is compiled to PTX; defaulting to float here
// is an assumption.
#ifndef scalar_t
#define scalar_t float
#endif

extern "C" __global__ void rotary_embedding_kernel_neox(
  const int64_t* __restrict__ positions,      // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,               // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                 // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  rotary_embedding_kernel<scalar_t, true>(positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
}

extern "C" __global__ void rotary_embedding_kernel_normal(
  const int64_t* __restrict__ positions,      // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,               // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                 // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // GPT-J style: IS_NEOX must be false here (the original passed true, which
  // would have made both entry points apply the NeoX layout).
  rotary_embedding_kernel<scalar_t, false>(positions, query, key, cos_sin_cache, rot_dim, query_stride, key_stride, num_heads, num_kv_heads, head_size);
}
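For reference, the per-pair rotation that `apply_rotary_embedding` implements can be written out explicitly. Each position's cache row holds `rot_dim / 2` cosines followed by `rot_dim / 2` sines; the NeoX path pairs element `i` with element `i + rot_dim / 2`, while the GPT-J path pairs elements `2i` and `2i + 1`. In both cases the pair `(x, y)` at offset `i` is rotated as

$$
x' = x\cos\theta_i - y\sin\theta_i,\qquad y' = y\cos\theta_i + x\sin\theta_i
$$

where $\theta_i$ is whatever angle the host wrote into `cos_sin_cache` for that position and offset (typically $\theta_i = \mathrm{pos}\cdot\mathrm{base}^{-2i/\mathrm{rot\_dim}}$ for standard RoPE; the kernel itself only consumes the cached cosines and sines).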
@@ -1,12 +1,140 @@
// Assumed import paths: `Device`, cudarc's `LaunchConfig`/`LaunchAsync`, and
// `APIError` are used below but were not imported in this hunk.
use candle_core::{Device, Tensor};
use cudarc::driver::{LaunchAsync, LaunchConfig};
use either::Either;

use crate::{
    backend::{get_or_load_func, ROTARY_EMBDEDDING_KERNEL, ROTARY_EMBDEDDING_PTX},
    try_api, APIError,
};

use super::dispatch_get_cuda_pointer;

// The device checks below return `Err(APIError)`, so the function is fallible.
pub fn rotary_embedding(
    positions: Tensor,
    query: &mut Tensor,
    key: &mut Tensor,
    head_size: usize,
    cos_sin_cache: Tensor,
    is_neox: bool,
) -> Result<(), APIError> {
    let positions_dev = positions.device();
    let Device::Cuda(dev) = positions_dev else {
        panic!("Expected the positions to be on a CUDA device.")
    };

    if !query.device().same_device(positions.device()) {
        return Err(APIError::new(format!(
            "`query` and `positions` have different devices, got {:?} and {:?} respectively.",
            query.device(),
            positions.device()
        )));
    }

    if !key.device().same_device(positions.device()) {
        return Err(APIError::new(format!(
            "`key` and `positions` have different devices, got {:?} and {:?} respectively.",
            key.device(),
            positions.device()
        )));
    }

    if !cos_sin_cache.device().same_device(positions.device()) {
        return Err(APIError::new(format!(
            "`cos_sin_cache` and `positions` have different devices, got {:?} and {:?} respectively.",
            cos_sin_cache.device(),
            positions.device()
        )));
    }

    let num_tokens = query.shape().elem_count() / query.shape().dims().last().unwrap();
    let rot_dim = *cos_sin_cache.shape().dims().get(1).unwrap();
    let num_heads = query.shape().dims().last().unwrap() / head_size;
    let num_kv_heads = key.shape().dims().last().unwrap() / head_size;
    // Stride of the token dimension; the original indexed `query`'s strides
    // with `key`'s length, which only holds when both tensors have the same rank.
    let query_stride = *query.stride().get(query.stride().len() - 2).unwrap();
    let key_stride = *key.stride().get(key.stride().len() - 2).unwrap();

    // One thread block per token; up to 512 threads cover the
    // `num_heads * rot_dim / 2` rotations each block performs.
    let launch_conf = LaunchConfig {
        grid_dim: (num_tokens.try_into().unwrap(), 1u32, 1u32),
        block_dim: (
            512.min((num_heads * rot_dim / 2).try_into().unwrap()),
            1u32,
            1u32,
        ),
        shared_mem_bytes: 0,
    };

||
let positions_ptr = dispatch_get_cuda_pointer(positions); | ||
let key_ptr = dispatch_get_cuda_pointer(key); | ||
let query_ptr = dispatch_get_cuda_pointer(query); | ||
let cos_sin_cache_ptr = dispatch_get_cuda_pointer(cos_sin_cache); | ||
|
||
let stream = try_api!(dev.fork_default_stream()); | ||
|
||
    // Pick the NeoX (rotate-halves) or GPT-J (rotate-pairs) entry point.
    let kernel = if is_neox {
        try_api!(get_or_load_func(
            ROTARY_EMBDEDDING_PTX,
            ROTARY_EMBDEDDING_KERNEL,
            Either::Right("_neox"),
            dev
        ))
    } else {
        try_api!(get_or_load_func(
            ROTARY_EMBDEDDING_PTX,
            ROTARY_EMBDEDDING_KERNEL,
            Either::Right("_normal"),
            dev
        ))
    };

||
try_api!(unsafe { | ||
kernel.launch_on_stream( | ||
&stream, | ||
launch_conf, | ||
( | ||
positions_ptr, | ||
query_ptr, | ||
key_ptr, | ||
cos_sin_cache_ptr, | ||
rot_dim, | ||
query_stride, | ||
key_stride, | ||
num_heads, | ||
num_kv_heads, | ||
head_size, | ||
), | ||
) | ||
}); | ||
|
||
    /*
    const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    VLLM_DISPATCH_FLOATING_TYPES(
        query.scalar_type(),
        "rotary_embedding",
        [&] {
            if (is_neox) {
                vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
                    positions.data_ptr<int64_t>(),
                    query.data_ptr<scalar_t>(),
                    key.data_ptr<scalar_t>(),
                    cos_sin_cache.data_ptr<scalar_t>(),
                    rot_dim,
                    query_stride,
                    key_stride,
                    num_heads,
                    num_kv_heads,
                    head_size);
            } else {
                vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
                    positions.data_ptr<int64_t>(),
                    query.data_ptr<scalar_t>(),
                    key.data_ptr<scalar_t>(),
                    cos_sin_cache.data_ptr<scalar_t>(),
                    rot_dim,
                    query_stride,
                    key_stride,
                    num_heads,
                    num_kv_heads,
                    head_size);
            }
        });*/

    Ok(())
}
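To make the launcher's calling convention concrete, here is a minimal, hypothetical call site. The shapes, the `F32` dtype, and the zero-filled cache are illustrative only, and the example assumes the fallible `Result<(), APIError>` signature used above; nothing here is taken verbatim from the repository.

```rust
use candle_core::{DType, Device, Tensor};

fn rotate_in_place() {
    let dev = Device::new_cuda(0).unwrap();
    let (num_tokens, num_heads, head_size, max_position) = (4, 8, 64, 2048);

    // One position per token to look up in the cos/sin cache.
    let positions = Tensor::new(&[0i64, 1, 2, 3], &dev).unwrap();
    // Activations flattened to [num_tokens, num_heads * head_size].
    let mut query = Tensor::zeros((num_tokens, num_heads * head_size), DType::F32, &dev).unwrap();
    let mut key = Tensor::zeros((num_tokens, num_heads * head_size), DType::F32, &dev).unwrap();
    // One row per position: head_size / 2 cosines followed by head_size / 2 sines
    // (zero-filled here; a real cache would hold the RoPE angles).
    let cos_sin_cache = Tensor::zeros((max_position, head_size), DType::F32, &dev).unwrap();

    // Rotates `query` and `key` in place on the GPU using the NeoX pairing.
    // Error handling is elided since `APIError`'s traits are not assumed here.
    let _ = rotary_embedding(positions, &mut query, &mut key, head_size, cos_sin_cache, true);
}
```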