diff --git a/setup.py b/setup.py
index bd9ce90..270382d 100644
--- a/setup.py
+++ b/setup.py
@@ -72,6 +72,7 @@ def read_requirements(path):
             "triteia/csrc/ops/marlin_nm.cu",
             "triteia/csrc/ops/triteia_nm_bmm.cu",
             "triteia/csrc/ops/triteia_nm_sbmm.cu",
+            "triteia/csrc/ops/pos_encoding_kernels.cu",
         ],
         dlink=True,
         extra_compile_args={
diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py
new file mode 100644
index 0000000..48a5971
--- /dev/null
+++ b/tests/ops/test_rotary_embedding.py
@@ -0,0 +1,142 @@
+from itertools import accumulate
+from typing import List, Optional
+
+import pytest
+import torch
+
+# NOTE: get_rope, seed_everything, get_default_atol and get_default_rtol are
+# assumed to come from this project's (or vLLM's) test utilities; their
+# import lines are not recoverable from this diff.
+
+IS_NEOX_STYLE = [True, False]
+DTYPES = [torch.half, torch.bfloat16, torch.float]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
+ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
+SEEDS = [0]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_batched_rotary_embedding(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    seed_everything(seed)
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
+        "type": "linear",
+        "factor": (1, )
+    })
+    rope = rope.to(dtype=dtype)
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len))
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype)
+    key = torch.randn_like(query)
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope.forward_native(positions, query, key)
+    out_query, out_key = rope.forward(positions,
+                                      query,
+                                      key,
+                                      offsets=torch.zeros(batch_size * seq_len,
+                                                          dtype=torch.long,
+                                                          device=device))
+    # Compare the results.
+    torch.testing.assert_close(out_query,
+                               ref_query,
+                               atol=get_default_atol(out_query),
+                               rtol=get_default_rtol(out_query))
+    torch.testing.assert_close(out_key,
+                               ref_key,
+                               atol=get_default_atol(out_key),
+                               rtol=get_default_rtol(out_key))
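For reference, the rotation this test verifies (via rope.forward_native and the kernel below) works on pairs of elements within each head: with d = rot_dim / 2, NeoX style pairs element i with element i + d, while GPT-J style pairs adjacent elements 2i and 2i + 1, and each pair is rotated by the cached angle for the token's position. A minimal illustrative sketch of that rotation for a single token and head (not code from this diff; rotate_token is a hypothetical name):

    import torch

    def rotate_token(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                     is_neox: bool) -> torch.Tensor:
        # x: [rot_dim] slice of one head; cos/sin: [rot_dim // 2] for this position.
        d = cos.numel()
        out = x.clone()
        for i in range(d):
            xi, yi = (i, i + d) if is_neox else (2 * i, 2 * i + 1)
            out[xi] = x[xi] * cos[i] - x[yi] * sin[i]
            out[yi] = x[yi] * cos[i] + x[xi] * sin[i]
        return out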
+
+
+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
+@pytest.mark.parametrize("num_heads", NUM_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_batched_rotary_embedding_multi_lora(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    seed_everything(seed)
+    torch.set_default_device(device)
+    if rotary_dim is None:
+        rotary_dim = head_size
+    scaling_factors: List[int] = [1, 2, 4]
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
+        "type": "linear",
+        "factor": tuple(scaling_factors)
+    })
+    rope = rope.to(dtype=dtype)
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len))
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype)
+    key = torch.randn_like(query)
+
+    offset_map = torch.tensor(
+        list(
+            accumulate([0] + [
+                max_position * scaling_factor * 2
+                for scaling_factor in scaling_factors[:-1]
+            ])))
+    query_types = torch.randint(0,
+                                len(scaling_factors), (batch_size, seq_len),
+                                device=device)
+    query_offsets = offset_map[query_types]
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
+    out_query, out_key = rope.forward(positions, query, key,
+                                      query_offsets.flatten())
+    # Compare the results.
+    torch.testing.assert_close(out_query,
+                               ref_query,
+                               atol=get_default_atol(out_query),
+                               rtol=get_default_rtol(out_query))
+    torch.testing.assert_close(out_key,
+                               ref_key,
+                               atol=get_default_atol(out_key),
+                               rtol=get_default_rtol(out_key))
\ No newline at end of file
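A worked example of the offset_map arithmetic above: each scaling factor owns a contiguous block of rows in the concatenated cos/sin cache, the test spaces the blocks by max_position * factor * 2 rows, and the map records where each block starts. With max_position = 8192 and factors [1, 2, 4] (illustrative numbers, computed from the test's own expression):

    from itertools import accumulate

    offsets = list(accumulate([0] + [8192 * f * 2 for f in (1, 2)]))
    # -> [0, 16384, 49152]
    # Queries tagged type 0 index the cache from row 0, type 1 from row 16384,
    # type 2 from row 49152; the kernel then adds the token's position within
    # that block.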
diff --git a/triteia/csrc/ops/common/cuda_compat.h b/triteia/csrc/ops/common/cuda_compat.h
new file mode 100644
index 0000000..3e7ba60
--- /dev/null
+++ b/triteia/csrc/ops/common/cuda_compat.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#ifdef USE_ROCM
+  #include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define WARP_SIZE warpSize
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_LDG(arg) __ldg(arg)
+#else
+  #define VLLM_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
+    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
+#else
+  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
+    __shfl_xor(var, lane_mask, width)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
+    __shfl_down_sync(uint32_t(-1), var, lane_delta)
+#else
+  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
+#endif
+
+#ifndef USE_ROCM
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif
\ No newline at end of file
diff --git a/triteia/csrc/ops/common/dispatch_utils.h b/triteia/csrc/ops/common/dispatch_utils.h
new file mode 100644
index 0000000..84605ee
--- /dev/null
+++ b/triteia/csrc/ops/common/dispatch_utils.h
@@ -0,0 +1,35 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#pragma once
+
+#include <torch/extension.h>
+
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
+                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
\ No newline at end of file
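In practice these macros pin down the dtype contract of the new ops: VLLM_DISPATCH_FLOATING_TYPES covers exactly float32, float16 and bfloat16, which is why DTYPES in the test file above lists precisely those three. A small Python-side restatement of that contract (illustrative only, not part of this diff):

    import torch

    # Dtypes handled by VLLM_DISPATCH_FLOATING_TYPES; any other input dtype
    # (e.g. torch.float64) fails at dispatch time with an unsupported-type error.
    SUPPORTED_DTYPES = {torch.float32, torch.float16, torch.bfloat16}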
diff --git a/triteia/csrc/ops/ops.cpp b/triteia/csrc/ops/ops.cpp
index 261817b..a0c3e8d 100644
--- a/triteia/csrc/ops/ops.cpp
+++ b/triteia/csrc/ops/ops.cpp
@@ -2,6 +2,9 @@
 #include <torch/all.h>
 #include <torch/python.h>
 #include <cuda_runtime.h>
+#include <torch/library.h>
+
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
 
 namespace marlin {
 int marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
@@ -99,18 +102,18 @@ void bmm_2_4(const torch::Tensor &A, const torch::Tensor &B,
 }
 
 void sbmm_forloop(const torch::Tensor &A, const torch::Tensor &B,
-    const torch::Tensor &meta, torch::Tensor &C,
-    const torch::Tensor &s, const torch::Tensor &indices,
-    torch::Tensor &workspace, const torch::Tensor &starts,
-    const torch::Tensor &counts, int thread_k = -1,
-    int thread_n = -1, int sms = -1, int max_par = 8) {
+                  const torch::Tensor &meta, torch::Tensor &C,
+                  const torch::Tensor &s, const torch::Tensor &indices,
+                  torch::Tensor &workspace, const torch::Tensor &starts,
+                  const torch::Tensor &counts, int thread_k = -1,
+                  int thread_n = -1, int sms = -1, int max_par = 8) {
   for (int i = 0; i < indices.size(0); i++) {
     int start = starts[i].item<int>();
     auto sliced_C = C.slice(0, start, start + counts[i].item<int>());
     auto my_workspace = workspace[i];
-    marlin::mul_2_4(A.slice(0, start, start + counts[i].item<int>()), B[indices[i]],
-        meta[indices[i]], sliced_C, s[indices[i]], my_workspace, thread_k,
-        thread_n, sms, max_par);
+    marlin::mul_2_4(A.slice(0, start, start + counts[i].item<int>()),
+                    B[indices[i]], meta[indices[i]], sliced_C, s[indices[i]],
+                    my_workspace, thread_k, thread_n, sms, max_par);
   }
 }
 
@@ -149,6 +152,18 @@ void sbmm_2_4(const torch::Tensor &A, const torch::Tensor &B,
   }
 }
 } // namespace triteia
 
+namespace vllm {
+void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
+                      torch::Tensor &key, int64_t head_size,
+                      torch::Tensor &cos_sin_cache, bool is_neox);
+
+void batched_rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
+                              torch::Tensor &key, int64_t head_size,
+                              torch::Tensor &cos_sin_cache, bool is_neox,
+                              int64_t rot_dim,
+                              torch::Tensor &cos_sin_cache_offsets);
+} // namespace vllm
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("mul_2_4", &marlin::mul_2_4,
         "Marlin FP16xINT4 matmul with 2:4 sparsity.");
@@ -156,4 +171,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("sbmm_forloop", &triteia::sbmm_forloop,
         "FP16xINT4 sbmm with 2:4 sparsity.");
   m.def("sbmm_2_4", &triteia::sbmm_2_4, "FP16xINT4 sbmm with 2:4 sparsity.");
+  m.def("rotary_embedding", &vllm::rotary_embedding,
+        "Apply GPT-NeoX or GPT-J style rotary embedding to query and key.");
+  m.def("batched_rotary_embedding", &vllm::batched_rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras).");
 }
\ No newline at end of file
diff --git a/triteia/csrc/ops/pos_encoding_kernels.cu b/triteia/csrc/ops/pos_encoding_kernels.cu
new file mode 100644
index 0000000..357aac0
--- /dev/null
+++ b/triteia/csrc/ops/pos_encoding_kernels.cu
@@ -0,0 +1,201 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "common/cuda_compat.h"
+#include "common/dispatch_utils.h"
+
+namespace vllm {
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_token_rotary_embedding(
+    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
+    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = VLLM_LDG(cos_ptr + x_index);
+    sin = VLLM_LDG(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = VLLM_LDG(cos_ptr + x_index / 2);
+    sin = VLLM_LDG(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* cache_ptr, const int head_size, const int num_heads,
+    const int num_kv_heads, const int rot_dim, const int token_idx,
+    const int64_t query_stride, const int64_t key_stride) {
+  const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
+        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+  }
+}
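The flattened loops above assign one (head, rotation pair) per flat index i: head_idx = i / embed_dim selects the head, rot_offset = i % embed_dim selects the pair within it, and the blockDim.x stride lets a fixed-size block cover any number of pairs. The same mapping in Python, for intuition (illustrative only; the concrete numbers come from the test's num_heads = 7, rotary_dim = 64 case):

    embed_dim = 64 // 2  # 32 pairs per head
    for i in range(7 * embed_dim):  # nq = 224 pairs per token
        head_idx, rot_offset = divmod(i, embed_dim)
        # One thread rotates elements (rot_offset, rot_offset + embed_dim) of
        # head head_idx in NeoX mode, or (2*rot_offset, 2*rot_offset + 1) in
        # GPT-J mode.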
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim
+                                                 // // 2]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void batched_rotary_embedding_kernel(
+    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
+                                            // [num_tokens]
+    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
+                                   // head_size] or [num_tokens, num_heads,
+                                   // head_size]
+    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
+                                   // head_size] or [num_tokens, num_kv_heads,
+                                   // head_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim
+                                                 // // 2]
+    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size,
+                                                        // seq_len] or
+                                                        // [num_tokens]
+    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+    const int num_heads, const int num_kv_heads, const int head_size) {
+  // Each thread block is responsible for one token.
+  const int token_idx = blockIdx.x;
+  int64_t pos = positions[token_idx];
+  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
+  const scalar_t* cache_ptr =
+      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
+
+  apply_rotary_embedding<scalar_t, IS_NEOX>(
+      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+      token_idx, query_stride, key_stride);
+}
+
+void rotary_embedding(
+    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size]
+                           // or [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
+    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+    bool is_neox) {
+  int64_t num_tokens = query.numel() / query.size(-1);
+  int rot_dim = cos_sin_cache.size(1);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
+    if (is_neox) {
+      vllm::rotary_embedding_kernel<scalar_t, true>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
+              head_size);
+    } else {
+      vllm::rotary_embedding_kernel<scalar_t, false>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
+              head_size);
+    }
+  });
+}
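To make the launch shape concrete: the grid has one block per token and the block is capped at 512 threads, so any pairs beyond the cap are covered by the stride loops in apply_rotary_embedding. With the larger test configuration above (a worked example, not part of the diff):

    # num_heads = 17, rot_dim = 64 -> embed_dim = 32, nq = 17 * 32 = 544 pairs.
    block = min(17 * 32, 512)  # = 512 threads
    # Threads 0..511 rotate query pairs 0..511 in the first pass; threads
    # 0..31 then wrap around (i += blockDim.x) to finish pairs 512..543, and
    # the same block repeats the walk over the key pairs.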
+
+/*
+Batched version of rotary embedding, pack multiple LoRAs together
+and process in batched manner.
+*/
+void batched_rotary_embedding(
+    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
+    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
+                           // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size]
+                           // or [num_tokens, num_kv_heads * head_size]
+    int64_t head_size,
+    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+    bool is_neox, int64_t rot_dim,
+    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
+) {
+  int64_t num_tokens = cos_sin_cache_offsets.size(0);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
+    if (is_neox) {
+      vllm::batched_rotary_embedding_kernel<scalar_t, true>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
+    } else {
+      vllm::batched_rotary_embedding_kernel<scalar_t, false>
+          <<<grid, block, 0, stream>>>(
+              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
+              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
+              key_stride, num_heads, num_kv_heads, head_size);
+    }
+  });
+}
+}  // namespace vllm
diff --git a/triteia/python/capi/rotary_embedding.py b/triteia/python/capi/rotary_embedding.py
new file mode 100644
index 0000000..51d5658
--- /dev/null
+++ b/triteia/python/capi/rotary_embedding.py
@@ -0,0 +1,23 @@
+import torch
+import triteia_cuda
+
+def rotary_embedding(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    head_size: int,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+) -> None:
+    triteia_cuda.rotary_embedding(positions, query, key, head_size,
+                                  cos_sin_cache, is_neox)
+
+
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+                             key: torch.Tensor, head_size: int,
+                             cos_sin_cache: torch.Tensor, is_neox: bool,
+                             rot_dim: int,
+                             cos_sin_cache_offsets: torch.Tensor) -> None:
+    triteia_cuda.batched_rotary_embedding(positions, query, key, head_size,
+                                          cos_sin_cache, is_neox, rot_dim,
+                                          cos_sin_cache_offsets)
\ No newline at end of file
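A minimal end-to-end sketch of driving the new Python wrapper, assuming the extension builds as triteia_cuda and using the cache layout the kernel indexing implies (each row holds rot_dim/2 cos values followed by rot_dim/2 sin values; the real cache construction lives in the higher-level rope helper, which this diff does not show):

    import torch
    from triteia.python.capi.rotary_embedding import rotary_embedding

    head_size, num_heads, max_position, base = 64, 8, 4096, 10000
    rot_dim = head_size
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
    angles = torch.outer(torch.arange(max_position).float(), inv_freq)
    # [max_position, rot_dim]: cos half then sin half, as the kernel expects.
    cos_sin_cache = torch.cat([angles.cos(), angles.sin()], dim=-1).half().cuda()

    positions = torch.randint(0, max_position, (2, 16), device="cuda")
    query = torch.randn(2, 16, num_heads * head_size,
                        dtype=torch.half, device="cuda")
    key = torch.randn_like(query)
    # Rotates query and key in place, NeoX-style pairing.
    rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)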