diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv
index 4e966cab2..398e03e01 100644
--- a/benchmark/data/all_benchmark_data.csv
+++ b/benchmark/data/all_benchmark_data.csv
@@ -714,4 +714,4 @@ fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765
 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
-fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
+fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
\ No newline at end of file
diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py
index 0cd88efeb..872802fed 100644
--- a/src/liger_kernel/ops/rope.py
+++ b/src/liger_kernel/ops/rope.py
@@ -117,7 +117,126 @@ def _triton_rope(
     tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


-def rope_forward(q, k, cos, sin):
+@triton.jit
+def _triton_rope_paper(
+    q_ptr,
+    q_row_stride,
+    k_ptr,
+    k_row_stride,
+    cos,
+    cos_row_stride,
+    sin,
+    sin_row_stride,
+    sl,
+    bs: tl.constexpr,
+    n_qh: tl.constexpr,
+    n_kh: tl.constexpr,
+    hd: tl.constexpr,
+    pad_n_qh: tl.constexpr,
+    pad_n_kh: tl.constexpr,
+    pad_hd: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    BACKWARD_PASS: tl.constexpr = False,
+):
+    # q size: (bsz, seq_len, num_q_heads, head_dim)
+    # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
+    # k size: (bsz, seq_len, num_kv_heads, head_dim)
+    # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)
+
+    # cos size: (1, seq_len, head_dim // 2)
+    # stride: (seq_len * head_dim // 2, head_dim // 2, 1)
+    pid = tl.program_id(0)
+
+    # locate start address
+    q_ptr = q_ptr + pid * q_row_stride
+    k_ptr = k_ptr + pid * k_row_stride
+
+    # ####################################################################
+    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
+    # m of this program instance
+    # ####################################################################
+
+    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
+    # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
+    # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
+    # and pid % sl to get the sequence index.
+    # 2. in the paper form, cos and sin store only head_dim // 2 values per position
+    # (one per even/odd pair), so a single row of length hd // 2 is loaded below.
+    cos_row_idx = pid % (sl)
+    cos = cos + cos_row_idx * cos_row_stride
+    sin = sin + cos_row_idx * sin_row_stride
+    cos_offsets = tl.arange(0, pad_hd // 2)
+    cos_mask = cos_offsets < hd // 2
+    cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
+    sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
+
+    # ####################################################################
+    # Load the even-indexed and odd-indexed elements of q and k for the current
+    # program instance (i.e. for the current token) separately
+    # ####################################################################
+    # even-indexed elements of the head
+    even_q_offsets = (
+        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2
+    )
+    even_k_offsets = (
+        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2
+    )
+    even_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] * 2 < hd
+    )
+    even_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] * 2 < hd
+    )
+    q_tile_even = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_even = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    # odd-indexed elements of the head
+    odd_q_offsets = even_q_offsets + 1
+    odd_k_offsets = even_k_offsets + 1
+    odd_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+        tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd
+    )
+    odd_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+        tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd
+    )
+    q_tile_odd = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+        sin_row.dtype
+    )
+    k_tile_odd = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+        sin_row.dtype
+    )
+
+    if not BACKWARD_PASS:
+        # y_even = x_even * cos - x_odd * sin
+        # y_odd = x_odd * cos + x_even * sin
+        new_q_tile_even = q_tile_even * cos_row - q_tile_odd * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask)
+        new_q_tile_odd = q_tile_odd * cos_row + q_tile_even * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask)
+
+        new_k_tile_even = k_tile_even * cos_row - k_tile_odd * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask)
+        new_k_tile_odd = k_tile_odd * cos_row + k_tile_even * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask)
+    else:
+        # dy_even = dx_even * cos + dx_odd * sin
+        # dy_odd = dx_odd * cos - dx_even * sin
+        new_q_tile_even = q_tile_even * cos_row + q_tile_odd * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask)
+        new_q_tile_odd = q_tile_odd * cos_row - q_tile_even * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask)
+
+        new_k_tile_even = k_tile_even * cos_row + k_tile_odd * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask)
+        new_k_tile_odd = k_tile_odd * cos_row - k_tile_even * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask)
+
+
+def rope_forward(q, k, cos, sin, paper_form):
     # transpose it back to the physical shape because Triton looks at the physical storage
     # note: q and k are incontiguous before the transformation and will become contiguous after transpose
@@ -139,30 +258,52 @@ def rope_forward(q, k, cos, sin):
     cos = cos.contiguous()
     sin = sin.contiguous()

-    _triton_rope[(n_row,)](
-        q,
-        q.stride(1),
-        k,
-        k.stride(1),
-        cos,
-        cos.stride(-2),
-        sin,
-        sin.stride(-2),
-        seq_len,
-        batch_size,
-        n_q_head,
-        n_kv_head,
-        head_dim,
-        pad_n_q_head,
-        pad_n_kv_head,
-        pad_hd,
-        BLOCK_SIZE=BLOCK_SIZE,
-        BACKWARD_PASS=False,
-    )
+    if not paper_form:
+        _triton_rope[(n_row,)](
+            q,
+            q.stride(1),
+            k,
+            k.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            seq_len,
+            batch_size,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=False,
+        )
+    else:
+        _triton_rope_paper[(n_row,)](
+            q,
+            q.stride(1),
+            k,
+            k.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            seq_len,
+            batch_size,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=False,
+        )

     return q.transpose(1, 2), k.transpose(1, 2), cos, sin


-def rope_backward(dq, dk, cos, sin):
+def rope_backward(dq, dk, cos, sin, paper_form):
     dq = dq.transpose(1, 2)
     dk = dk.transpose(1, 2)

@@ -180,51 +321,85 @@ def rope_backward(dq, dk, cos, sin):
     dk = dk.contiguous()

     # backward is similar to forward except swapping few ops
-    _triton_rope[(n_row,)](
-        dq,
-        dq.stride(1),
-        dk,
-        dk.stride(1),
-        cos,
-        cos.stride(-2),
-        sin,
-        sin.stride(-2),
-        seq_len,
-        batch_size,
-        n_q_head,
-        n_kv_head,
-        head_dim,
-        pad_n_q_head,
-        pad_n_kv_head,
-        pad_hd,
-        BLOCK_SIZE=BLOCK_SIZE,
-        BACKWARD_PASS=True,
-    )
+    if not paper_form:
+        _triton_rope[(n_row,)](
+            dq,
+            dq.stride(1),
+            dk,
+            dk.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            seq_len,
+            batch_size,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=True,
+        )
+    else:
+        _triton_rope_paper[(n_row,)](
+            dq,
+            dq.stride(1),
+            dk,
+            dk.stride(1),
+            cos,
+            cos.stride(-2),
+            sin,
+            sin.stride(-2),
+            seq_len,
+            batch_size,
+            n_q_head,
+            n_kv_head,
+            head_dim,
+            pad_n_q_head,
+            pad_n_kv_head,
+            pad_hd,
+            BLOCK_SIZE=BLOCK_SIZE,
+            BACKWARD_PASS=True,
+        )

     return dq.transpose(1, 2), dk.transpose(1, 2)


 class LigerRopeFunction(torch.autograd.Function):
     """
-    Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that
-    this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different
-    than the original RoPE paper.
+    Triton implementation of the Rotary Positional Embedding (RoPE) operation.
+    This implements both the HuggingFace Llama & Mistral version and the original RoPE paper version.

-    Please find the corresponding HuggingFace implementation here:
+    Please find the corresponding HuggingFace Llama & Mistral implementation here:
     https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184

+    Please find the corresponding HuggingFace paper-form implementation here:
+    https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/roformer/modeling_roformer.py#L309
+
     For more details about the rotation matrix used here, please refer to:
     https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2
     """

     @staticmethod
-    def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    def forward(
+        ctx,
+        q,
+        k,
+        cos,
+        sin,
+        position_ids=None,
+        unsqueeze_dim=1,
+        paper_form: bool = False,
+    ):
         """
         q size: (bsz, n_q_head, seq_len, head_dim)
         k size: (bsz, n_kv_head, seq_len, head_dim)
         cos size: (1, seq_len, head_dim)
         sin size: (1, seq_len, head_dim)
         """
-        q, k, cos, sin = rope_forward(q, k, cos, sin)
+        q, k, cos, sin = rope_forward(q, k, cos, sin, paper_form)
+        ctx.paper_form = paper_form
         ctx.save_for_backward(cos, sin)
         return q, k

@@ -237,5 +412,5 @@ def backward(ctx, dq, dk):
         """
         cos, sin = ctx.saved_tensors

-        dq, dk = rope_backward(dq, dk, cos, sin)
-        return dq, dk, None, None, None, None
+        dq, dk = rope_backward(dq, dk, cos, sin, ctx.paper_form)
+        return dq, dk, None, None, None, None, None
diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py
index 45ad6159a..66316056b 100644
--- a/src/liger_kernel/transformers/functional.py
+++ b/src/liger_kernel/transformers/functional.py
@@ -165,8 +165,12 @@ def liger_rms_norm(
     return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place)


-def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-    return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim)
+def liger_rope(
+    q, k, cos, sin, position_ids=None, unsqueeze_dim=1, paper_form: bool = False
+):
+    return LigerRopeFunction.apply(
+        q, k, cos, sin, position_ids, unsqueeze_dim, paper_form
+    )


 def liger_swiglu(a, b):
diff --git a/src/liger_kernel/transformers/rope.py b/src/liger_kernel/transformers/rope.py
index a40b29af3..8b7ea8e61 100644
--- a/src/liger_kernel/transformers/rope.py
+++ b/src/liger_kernel/transformers/rope.py
@@ -1,7 +1,9 @@
 from liger_kernel.ops.rope import LigerRopeFunction


-def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+def liger_rotary_pos_emb(
+    q, k, cos, sin, position_ids=None, unsqueeze_dim=1, paper_form: bool = False
+):
     """
     Applies Rotary Positional Embedding (RoPE) operation to query and key states.

@@ -12,9 +14,12 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
         sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
         position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
         unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
+        paper_form (bool, optional): Whether to use the paper-form RoPE rotary matrix. Defaults to False.

     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
""" - return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) + return LigerRopeFunction.apply( + q, k, cos, sin, position_ids, unsqueeze_dim, paper_form + ) diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 74080b57f..7f341a6d0 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -6,12 +6,19 @@ LlamaRotaryEmbedding, apply_rotary_pos_emb, ) +from transformers.models.roformer.modeling_roformer import ( + RoFormerSelfAttention, + RoFormerSinusoidalPositionalEmbedding, +) from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.transformers.functional import liger_rope from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.utils import infer_device +apply_paper_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings + + device = infer_device() SLEEP_SECONDS = 0.1 @@ -46,10 +53,16 @@ ), ], ) +@pytest.mark.parametrize( + "paper_form", + [ + False, + True, + ], +) def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol, paper_form ): - rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) _tensor_q = ( torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) @@ -69,12 +82,25 @@ def test_correctness( q2 = _tensor_q.clone().requires_grad_(True) k2 = _tensor_k.clone().requires_grad_(True) - pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - cos, sin = rotary_emb(k1, pos_ids) - # validate forward pass - hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) - tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin) + if not paper_form: + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) + + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + cos, sin = rotary_emb(k1, pos_ids) + + hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) + else: + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + + sinusoidal_pos = rotary_emb((bsz, seq_len)).unsqueeze(0).to(dtype) + + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + hf_q, hf_k = apply_paper_rotary_pos_emb(sinusoidal_pos.unsqueeze(0), q1, k1) + + tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin, paper_form=paper_form) assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol) @@ -111,8 +137,15 @@ def test_correctness( (torch.bfloat16, 1e-1, 1e-5), ], ) +@pytest.mark.parametrize( + "paper_form", + [ + False, + True, + ], +) def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol, paper_form ): _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) @@ -123,13 +156,27 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) - - pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - cos, sin = rotary_emb(k1, pos_ids) - - functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin) - class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin) + if not paper_form: + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) + + pos_ids = 
+        cos, sin = rotary_emb(k1, pos_ids)
+    else:
+        rotary_emb = RoFormerSinusoidalPositionalEmbedding(
+            num_positions=seq_len, embedding_dim=head_dim
+        ).to(device)
+        sinusoidal_pos = rotary_emb((bsz, seq_len)).unsqueeze(0).to(dtype)
+
+        sin, cos = sinusoidal_pos.chunk(2, dim=-1)
+
+    functional_q, functional_k = liger_rope(
+        q=q1,
+        k=k1,
+        cos=cos,
+        sin=sin,
+        paper_form=paper_form,
+    )
+    class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin, None, 1, paper_form)

     assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol)
     assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol)
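
Reviewer note (not part of the patch): a minimal pure-PyTorch sketch of the two rotation conventions this change distinguishes, for anyone comparing the kernels against a reference. It assumes cos/sin shaped as in the docstrings above ((1, seq_len, head_dim) for the HF form, (1, seq_len, head_dim // 2) for the paper form); the helper names rope_hf_form and rope_paper_form are illustrative only and do not exist in this repo.

import torch


def rope_hf_form(x, cos, sin):
    # HuggingFace Llama/Mistral convention: element i is rotated together with
    # element i + head_dim // 2 ("rotate_half").
    # x: (bsz, n_head, seq_len, head_dim); cos/sin: (1, seq_len, head_dim)
    x1, x2 = x.chunk(2, dim=-1)
    rotate_half = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotate_half * sin


def rope_paper_form(x, cos, sin):
    # Original RoPE paper / RoFormer convention: element 2i is rotated together
    # with element 2i + 1, which is what _triton_rope_paper computes.
    # x: (bsz, n_head, seq_len, head_dim); cos/sin: (1, seq_len, head_dim // 2)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin  # y_even, matches the forward kernel
    out[..., 1::2] = x_odd * cos + x_even * sin  # y_odd
    return out

In both conventions the backward pass applies the transposed rotation (the sign on the sin terms is flipped), which is what the BACKWARD_PASS=True branches of the two Triton kernels implement.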