Showing 7 changed files with 477 additions and 8 deletions.
@@ -0,0 +1,142 @@
# NOTE: the imports below are reconstructed so the test is self-contained.
# The helper locations are assumed from vLLM's usual test layout and may
# differ in this repository.
from itertools import accumulate
from typing import List, Optional

import pytest
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.utils import seed_everything

from .allclose_default import get_default_atol, get_default_rtol

IS_NEOX_STYLE = [True, False]
DTYPES = [torch.half, torch.bfloat16, torch.float]
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [7, 17]  # Arbitrary values for testing
BATCH_SIZES = [1, 5]  # Arbitrary values for testing
SEQ_LENS = [11, 8192]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": (1, )
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key)
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
                                      offsets=torch.zeros(batch_size * seq_len,
                                                          dtype=torch.long,
                                                          device=device))
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))


@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    seed_everything(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    scaling_factors: List[int] = [1, 2, 4]
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # Randomly assign one scaling factor to each (batch, seq) position and
    # look up that factor's starting offset in the concatenated rotary cache.
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    query_offsets = offset_map[query_types]

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope.forward_native(positions, query, key,
                                             query_offsets)
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.
    torch.testing.assert_close(out_query,
                               ref_query,
                               atol=get_default_atol(out_query),
                               rtol=get_default_rtol(out_query))
    torch.testing.assert_close(out_key,
                               ref_key,
                               atol=get_default_atol(out_key),
                               rtol=get_default_rtol(out_key))
@@ -0,0 +1,49 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
  __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
  __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
  __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
  __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
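The header above only papers over naming differences between CUDA and ROCm warp intrinsics; nothing in this listing calls the macros directly. As a rough illustration of the intent, a warp-wide sum reduction written against these macros could look like the sketch below. This is a hypothetical example, not part of the commit, and the function name warp_reduce_sum is invented for illustration.

// Hypothetical sketch: a butterfly (XOR) warp reduction built on the
// portability macros above, so the same source compiles with
// __shfl_xor_sync on CUDA and __shfl_xor on ROCm.
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T val) {
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    // Fold in the value from the lane whose ID differs in exactly one bit.
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;  // every lane in the warp now holds the full sum
}

Because WARP_SIZE expands to warpSize on ROCm (64 on most AMD GPUs) and to 32 on CUDA, the loop bound adapts to the target automatically.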
@@ -0,0 +1,35 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)            \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
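These macros are thin wrappers around ATen's AT_DISPATCH_SWITCH / AT_DISPATCH_CASE, restricted to the dtype sets the kernels care about. A hedged usage sketch follows: copy_kernel and copy_op are illustrative names, not APIs introduced by this commit, and the sketch assumes the macros above and <torch/all.h> are in scope.

// Hypothetical sketch: dispatching a templated CUDA kernel launch over
// float / half / bfloat16 based on a tensor's runtime dtype.
template <typename scalar_t>
__global__ void copy_kernel(const scalar_t* __restrict__ src,
                            scalar_t* __restrict__ dst, int64_t n) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}

void copy_op(torch::Tensor& dst, const torch::Tensor& src) {
  const int64_t n = src.numel();
  const dim3 block(256);
  const dim3 grid((n + block.x - 1) / block.x);
  // Inside the lambda, `scalar_t` is the concrete type selected by the
  // matching AT_DISPATCH_CASE expansion for src's runtime dtype.
  VLLM_DISPATCH_FLOATING_TYPES(src.scalar_type(), "copy_op", [&] {
    copy_kernel<scalar_t><<<grid, block>>>(src.data_ptr<scalar_t>(),
                                           dst.data_ptr<scalar_t>(), n);
  });
}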