Support softcapping (Gemma-2 models) (#86)

guoqingbao authored Aug 21, 2024
1 parent d0a1060 commit c170b23
Showing 16 changed files with 172 additions and 36 deletions.
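
For context (not part of the diff): softcapping bounds each query/key attention logit to the open interval (-cap, cap) by rescaling it as cap * tanh(logit / cap), which Gemma-2 requires for its attention scores. The changes below thread a `softscapping` scalar through the FFI bindings and the paged-attention CUDA kernels, with 1.0 acting as the "capping disabled" value. A minimal CPU-side sketch of the math (plain Rust, illustrative only):

```rust
/// Softcapping as applied per attention logit: roughly linear for small
/// values, smoothly saturating toward +/-cap for large ones. A cap of 1.0
/// mirrors the kernel's `softscapping != 1.0` guard, i.e. capping disabled.
fn softcap(logit: f32, cap: f32) -> f32 {
    if cap == 1.0 {
        logit
    } else {
        (logit / cap).tanh() * cap
    }
}

fn main() {
    println!("{:.3}", softcap(120.0, 50.0)); // ~49.18: large logits are squashed below the cap
    println!("{:.3}", softcap(0.5, 50.0));   // ~0.500: small logits are almost unchanged
}
```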
2 changes: 1 addition & 1 deletion examples/benchmark.py
@@ -57,7 +57,7 @@ async def benchmark():

# avoid generating very short answers
for i in range(len(prompts)):
-prompts[i] = prompts[i] + " Respond in more than {} words.".format((int(max_tokens / 10) + 1) * 10)
+prompts[i] = prompts[i] + " Respond in more than {} words.".format(int(max_tokens / 10) * 10)

# send 16 chat requests at the same time
tasks: List[asyncio.Task] = []
2 changes: 2 additions & 0 deletions kernels/src/ffi.rs
@@ -40,6 +40,7 @@ extern "C" {
kv_head_stride: c_int,

dtype: u32,
+softscapping: f32,
);

pub fn paged_attention_v2(
@@ -66,5 +67,6 @@ extern "C" {
kv_head_stride: c_int,

dtype: u32,
+softscapping: f32,
);
}
2 changes: 1 addition & 1 deletion kernels/src/lib.rs
@@ -3,4 +3,4 @@ pub const COPY_BLOCKS_KERNEL: &str =
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));
-pub mod ffi;
+pub mod ffi;
56 changes: 42 additions & 14 deletions kernels/src/pagedattention.cu
@@ -73,6 +73,20 @@ inline __device__ float block_sum(float* red_smem, float sum) {
return VLLM_SHFL_SYNC(sum, 0);
}

+inline __device__ float fast_tanh(float x) {
+#if defined(__CUDA_ARCH__)
+#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750)
+float y;
+asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x));
+return y;
+#else
+return ::tanhf(x);
+#endif
+#else
+return std::tanh(x);
+#endif
+}
+
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
@@ -96,7 +110,8 @@ __device__ void paged_attention_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
-const int kv_head_stride) {
+const int kv_head_stride,
+const float softscapping) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
@@ -212,6 +227,10 @@ __device__ void paged_attention_kernel(
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+
+if (softscapping != 1.0) {
+qk = fast_tanh(qk / softscapping) * softscapping;
+}
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;

@@ -409,11 +428,12 @@ __global__ void paged_attention_v1_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
-const int kv_head_stride) {
+const int kv_head_stride,
+const float softscapping) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, softscapping);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -438,11 +458,12 @@ __global__ void paged_attention_v2_kernel(
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
-const int kv_head_stride) {
+const int kv_head_stride,
+const float softscapping) {
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-q_stride, kv_block_stride, kv_head_stride);
+q_stride, kv_block_stride, kv_head_stride, softscapping);
}

// Grid: (num_heads, num_seqs).
@@ -564,7 +585,8 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
-kv_head_stride);
+kv_head_stride,\
+softscapping);

// TODO(woosuk): Tune NUM_THREADS.
template<
@@ -588,7 +610,8 @@ void paged_attention_v1_launcher(
int max_num_blocks_per_seq,
int q_stride,
int kv_block_stride,
-int kv_head_stride
+int kv_head_stride,
+float softscapping
) {

// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
@@ -652,7 +675,8 @@ void paged_attention_v1_launcher(
max_num_blocks_per_seq, \
q_stride, \
kv_block_stride, \
-kv_head_stride);
+kv_head_stride, \
+softscapping);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -691,7 +715,8 @@ extern "C" void paged_attention_v1(
int32_t kv_block_stride,
int32_t kv_head_stride,

-uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
+uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
+float softscapping
) {
if (dtype == 2) {
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
@@ -719,7 +744,8 @@ extern "C" void paged_attention_v1(
alibi_slopes, \
q_stride, \
kv_block_stride, \
-kv_head_stride); \
+kv_head_stride,\
+softscapping); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
reinterpret_cast<T*>(out), \
@@ -754,8 +780,8 @@ void paged_attention_v2_launcher(
int max_num_blocks_per_seq,
int q_stride,
int kv_block_stride,
-int kv_head_stride
-
+int kv_head_stride,
+float softscapping
) {
// int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);

@@ -825,7 +851,8 @@ void paged_attention_v2_launcher(
max_num_blocks_per_seq, \
q_stride, \
kv_block_stride, \
-kv_head_stride);
+kv_head_stride,\
+softscapping);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -867,7 +894,8 @@ extern "C" void paged_attention_v2(
int32_t kv_block_stride,
int32_t kv_head_stride,

-uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32
+uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32
+float softscapping
) {
if (dtype == 2) {
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
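
The kernel-side change above threads the `softscapping` scalar from the `paged_attention_v1`/`paged_attention_v2` entry points through the launchers into `paged_attention_kernel`, where the scaled `qk` dot product is capped via `fast_tanh(qk / softscapping) * softscapping`; `fast_tanh` uses the `tanh.approx.f32` PTX instruction on compute capability 7.5+ with CUDA 11+, and falls back to `tanhf` (device) or `std::tanh` (host) otherwise. Restated in plain Rust for readers skimming the hunks (the function name and signature below are illustrative, not from the repository):

```rust
/// Per-logit computation mirroring paged_attention_kernel above:
/// scale the Q.K dot product, softcap it (1.0 disables capping),
/// then add the optional ALiBi bias.
fn qk_logit(
    dot: f32,
    scale: f32,
    softcapping: f32,
    alibi_slope: f32,
    token_idx: i32,
    context_len: i32,
) -> f32 {
    let mut qk = scale * dot;
    if softcapping != 1.0 {
        qk = (qk / softcapping).tanh() * softcapping;
    }
    // A slope of 0.0 means "no ALiBi", matching the kernel's ternary.
    if alibi_slope != 0.0 {
        qk += alibi_slope * (token_idx - context_len + 1) as f32;
    }
    qk
}

fn main() {
    // With capping disabled (1.0) the logit is just scale * dot (+ bias): 0.125 * 8.0 = 1.0.
    assert_eq!(qk_logit(8.0, 0.125, 1.0, 0.0, 0, 1), 1.0);
    // With a cap of 50.0 the same small logit is essentially unchanged (prints 1.000).
    println!("{:.3}", qk_logit(8.0, 0.125, 50.0, 0.0, 0, 1));
}
```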
6 changes: 5 additions & 1 deletion src/backend/paged_attention.rs
@@ -11,7 +11,7 @@ use std::ffi::c_int;

struct PagedAttention {
softmax_scale: f32,
-
+softcapping: f32,
key_cache: Tensor,
value_cache: Tensor,
block_tables: Tensor,
@@ -187,6 +187,7 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
+self.softcapping,
)
}
} else {
@@ -223,6 +224,7 @@ impl PagedAttention {
kv_block_stride as c_int,
kv_head_stride as c_int,
internal_type,
+self.softcapping,
)
}
}
@@ -277,6 +279,7 @@ pub fn paged_attention(
context_lens: &Tensor,
max_context_len: usize,
softmax_scale: f32,
+softcapping: f32,
) -> Result<Tensor> {
let op = PagedAttention {
softmax_scale,
@@ -285,6 +288,7 @@ pub fn paged_attention(
block_tables: block_tables.clone(),
context_lens: context_lens.clone(),
max_context_len,
+softcapping,
};
q.apply_op1(op)
}
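
On the Rust side, the `PagedAttention` op now stores `softcapping` and passes it as the final argument of both FFI calls, and the public `paged_attention()` helper accepts it as a new `softcapping: f32` parameter. Callers pass the model's attention cap, or 1.0 to leave capping disabled; for Gemma-2 the value would come from the model configuration (commonly exposed as `attn_logit_softcapping` with a default of 50.0; that field name and value are assumptions, not taken from this diff). A hypothetical helper sketch:

```rust
/// Hypothetical helper (not from this repo): map an optional per-model
/// attention softcapping value onto the argument `paged_attention` expects.
/// `None` (most models) becomes 1.0, which the CUDA kernel treats as "no capping".
fn attention_softcapping(attn_logit_softcapping: Option<f32>) -> f32 {
    attn_logit_softcapping.unwrap_or(1.0)
}

fn main() {
    assert_eq!(attention_softcapping(None), 1.0);        // e.g. LLaMA-style models
    assert_eq!(attention_softcapping(Some(50.0)), 50.0); // e.g. a Gemma-2 style config
}
```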
