Commit 8135d3a

add rotary embedding

xzyaoi committed Sep 27, 2024
1 parent c1e7da6 commit 8135d3a
Showing 7 changed files with 477 additions and 8 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -72,6 +72,7 @@ def read_requirements(path):
"triteia/csrc/ops/marlin_nm.cu",
"triteia/csrc/ops/triteia_nm_bmm.cu",
"triteia/csrc/ops/triteia_nm_sbmm.cu",
"triteia/csrc/ops/pos_encoding_kernels.cu",
],
dlink=True,
extra_compile_args={
142 changes: 142 additions & 0 deletions tests/ops/test_rotary_embedding.py
@@ -0,0 +1,142 @@
import pytest
import torch
from itertools import accumulate
from typing import List, Optional

# NOTE: get_rope, seed_everything, get_default_atol and get_default_rtol are
# assumed to be provided by the project's rotary-embedding module and test
# utilities (as in vLLM, from which this test is adapted); adjust the import
# paths for this repository.

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"factor": (1, )
})
rope = rope.to(dtype=dtype)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions,
query,
key,
offsets=torch.zeros(batch_size * seq_len,
dtype=torch.long,
device=device))
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
seed_everything(seed)
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
scaling_factors: List[int] = [1, 2, 4]
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
"type": "linear",
"factor": tuple(scaling_factors)
})
rope = rope.to(dtype=dtype)

positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query)

offset_map = torch.tensor(
list(
accumulate([0] + [
max_position * scaling_factor * 2
for scaling_factor in scaling_factors[:-1]
])))
query_types = torch.randint(0,
len(scaling_factors), (batch_size, seq_len),
device=device)
query_offsets = offset_map[query_types]

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.
torch.testing.assert_close(out_query,
ref_query,
atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query))
torch.testing.assert_close(out_key,
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
49 changes: 49 additions & 0 deletions triteia/csrc/ops/common/cuda_compat.h
@@ -0,0 +1,49 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
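
The macros in cuda_compat.h let the same kernel source build for CUDA and ROCm: WARP_SIZE, VLLM_LDG and the VLLM_SHFL_* wrappers pick the right intrinsic for each platform. As a rough illustration only (not part of this commit; warp_reduce_sum is a hypothetical helper), a warp-level sum reduction written against these shims might look like:

#include "cuda_compat.h"

// Hypothetical example: butterfly reduction across one warp.
// Each step XOR-shuffles the partial sum from a lane `mask` away, so after
// log2(WARP_SIZE) steps every lane in the warp holds the full sum.
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}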
35 changes: 35 additions & 0 deletions triteia/csrc/ops/common/dispatch_utils.h
@@ -0,0 +1,35 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
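
dispatch_utils.h wraps ATen's AT_DISPATCH_SWITCH/AT_DISPATCH_CASE so a launcher can be instantiated for float, half and bfloat16 with a single macro. A minimal sketch of the usual pattern, assuming a made-up scale_kernel (not part of this commit):

#include <torch/all.h>
#include "dispatch_utils.h"

// Hypothetical elementwise kernel used only to illustrate the dispatch macro.
template <typename scalar_t>
__global__ void scale_kernel(scalar_t* __restrict__ out,
                             const scalar_t* __restrict__ in,
                             float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<scalar_t>(static_cast<float>(in[i]) * factor);
}

void scale(torch::Tensor& out, const torch::Tensor& in, float factor) {
  int n = in.numel();
  dim3 grid((n + 255) / 256), block(256);
  // Inside the macro body, scalar_t is bound to float, at::Half or
  // at::BFloat16 depending on in.scalar_type().
  VLLM_DISPATCH_FLOATING_TYPES(in.scalar_type(), "scale_kernel", [&] {
    scale_kernel<scalar_t><<<grid, block>>>(out.data_ptr<scalar_t>(),
                                            in.data_ptr<scalar_t>(), factor, n);
  });
}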
34 changes: 26 additions & 8 deletions triteia/csrc/ops/ops.cpp
@@ -2,6 +2,9 @@
#include <cuda_runtime.h>
#include <torch/all.h>
#include <torch/python.h>
#include <torch/library.h>

#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

namespace marlin {
int marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C,
@@ -99,18 +102,18 @@ void bmm_2_4(const torch::Tensor &A, const torch::Tensor &B,
}

void sbmm_forloop(const torch::Tensor &A, const torch::Tensor &B,
                  const torch::Tensor &meta, torch::Tensor &C,
                  const torch::Tensor &s, const torch::Tensor &indices,
                  torch::Tensor &workspace, const torch::Tensor &starts,
                  const torch::Tensor &counts, int thread_k = -1,
                  int thread_n = -1, int sms = -1, int max_par = 8) {
for (int i = 0; i < indices.size(0); i++) {
int start = starts[i].item<int>();
auto sliced_C = C.slice(0, start, start + counts[i].item<int>());
auto my_workspace = workspace[i];
    marlin::mul_2_4(A.slice(0, start, start + counts[i].item<int>()),
                    B[indices[i]], meta[indices[i]], sliced_C, s[indices[i]],
                    my_workspace, thread_k, thread_n, sms, max_par);
}
}

@@ -149,11 +152,26 @@ void sbmm_2_4(const torch::Tensor &A, const torch::Tensor &B,
}
} // namespace triteia

namespace vllm {
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
torch::Tensor &key, int64_t head_size,
torch::Tensor &cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
torch::Tensor &key, int64_t head_size,
torch::Tensor &cos_sin_cache, bool is_neox,
int64_t rot_dim,
torch::Tensor &cos_sin_cache_offsets);
} // namespace vllm

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("mul_2_4", &marlin::mul_2_4,
"Marlin FP16xINT4 matmul with 2:4 sparsity.");
m.def("bmm_2_4", &triteia::bmm_2_4, "FP16xINT4 bmm with 2:4 sparsity.");
m.def("sbmm_forloop", &triteia::sbmm_forloop,
"FP16xINT4 sbmm with 2:4 sparsity.");
m.def("sbmm_2_4", &triteia::sbmm_2_4, "FP16xINT4 sbmm with 2:4 sparsity.");
m.def("rotary_embedding", &vllm::rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key.");
m.def("batched_rotary_embedding", &vllm::batched_rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras).");
}
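
The declarations above come from vLLM's pos_encoding_kernels.cu, which setup.py now compiles but which is not rendered in this diff view. For orientation, a hedged sketch of the per-pair rotation that kernel applies, following vLLM's implementation: each (x, y) pair within rot_dim of a query/key head is rotated by the cached angle for its position, with GPT-NeoX style pairing the two halves of rot_dim and GPT-J style pairing adjacent even/odd elements.

// Sketch adapted from vLLM's pos_encoding_kernels.cu (not shown in this
// diff); the include path assumes the common/ layout added by this commit.
#include "common/cuda_compat.h"

template <typename scalar_t, bool IS_NEOX>
__device__ void apply_rotary(scalar_t* __restrict__ arr,        // one head, head_size elems
                             const scalar_t* __restrict__ cos_ptr,
                             const scalar_t* __restrict__ sin_ptr,
                             int rot_offset, int embed_dim) {    // embed_dim = rot_dim / 2
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style: the first half of rot_dim pairs with the second half.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style: adjacent even/odd elements form a pair.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

The rotation is applied in place, which is why the Python test runs the native reference implementation before calling the custom op.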