From edd4d863d842d0da0f30c590bae2f818e30bb950 Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Tue, 10 Dec 2024 21:44:43 +0800 Subject: [PATCH 1/8] Add original RoPE kernel --- src/liger_kernel/ops/rope_paper.py | 241 ++++++++++++++++++++ src/liger_kernel/transformers/__init__.py | 3 + src/liger_kernel/transformers/functional.py | 5 + src/liger_kernel/transformers/rope_paper.py | 20 ++ 4 files changed, 269 insertions(+) create mode 100644 src/liger_kernel/ops/rope_paper.py create mode 100644 src/liger_kernel/transformers/rope_paper.py diff --git a/src/liger_kernel/ops/rope_paper.py b/src/liger_kernel/ops/rope_paper.py new file mode 100644 index 000000000..aca347ed5 --- /dev/null +++ b/src/liger_kernel/ops/rope_paper.py @@ -0,0 +1,241 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _triton_rope_paper( + q_ptr, + q_row_stride, + k_ptr, + k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim // 2) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. cos and sin matrices are already in the shape (1, seq_len, head_dim // 2), so we simply load the entire matrix. + cos_row_idx = pid % (sl) + cos = cos + cos_row_idx * cos_row_stride + sin = sin + cos_row_idx * sin_row_stride + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the even-indexed and odd-indexed elements of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # even-indexed elements of the head + even_q_offsets = ( + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 + ) + even_k_offsets = ( + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 + ) + even_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 < hd + ) + even_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 < hd + ) + q_tile_even = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_even = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to( + sin_row.dtype + ) + + # odd-indexed elements of the head + odd_q_offsets = even_q_offsets + 1 + odd_k_offsets = even_k_offsets + 1 + odd_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd + ) + odd_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd + ) + q_tile_odd = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_odd = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to( + sin_row.dtype + ) + + if not BACKWARD_PASS: + # y_even = x_even * cos - x_odd * sin + # y_odd = x_odd * cos + x_even * sin + new_q_tile_even = q_tile_even * cos_row - q_tile_odd * sin_row + tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) + new_q_tile_odd = q_tile_odd * cos_row + q_tile_even * sin_row + tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) + + new_k_tile_even = k_tile_even * cos_row - k_tile_odd * sin_row + tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) + new_k_tile_odd = k_tile_odd * cos_row + k_tile_even * sin_row + tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) + else: + # dy_even = dx_even * cos + dx_odd * sin + # dy_odd = dx_odd * cos - dx_even * sin + new_q_tile_even = q_tile_even * cos_row + q_tile_odd * sin_row + tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) + new_q_tile_odd = q_tile_odd * cos_row - q_tile_even * sin_row + tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) + + new_k_tile_even = k_tile_even * cos_row + k_tile_odd * sin_row + tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) + new_k_tile_odd = k_tile_odd * cos_row - k_tile_even * sin_row + tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) + + +def rope_paper_forward(q, k, cos, sin): + + # transpose it back to the physical shape because Triton looks at the physical storage + # note: q and k are incontiguous before the transformation and will become contiguous after transpose + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = q.shape + n_kv_head = k.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous + q = q.contiguous() + k = k.contiguous() + cos = cos.contiguous() + sin = sin.contiguous() + + _triton_rope_paper[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + return q.transpose(1, 2), k.transpose(1, 2), cos, sin + + +def rope_paper_backward(dq, dk, cos, sin): + dq = dq.transpose(1, 2) + dk = dk.transpose(1, 2) + + batch_size, seq_len, n_q_head, head_dim = dq.shape + n_kv_head = dk.shape[2] + pad_hd = triton.next_power_of_2(head_dim) + pad_n_q_head = triton.next_power_of_2(n_q_head) + pad_n_kv_head = triton.next_power_of_2(n_kv_head) + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) + + n_row = batch_size * seq_len + + # ensure dq and dk are contiguous + dq = dq.contiguous() + dk = dk.contiguous() + + # backward is similar to forward except swapping few ops + _triton_rope_paper[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + return dq.transpose(1, 2), dk.transpose(1, 2) + + +class LigerRopePaperFunction(torch.autograd.Function): + """ + Triton implementation of the orignal Rotary Positional Embedding (RoPE) operation from RoFormer. + + Please find the corresponding HuggingFace implementation here: + https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/roformer/modeling_roformer.py#L309 + + """ + + @staticmethod + def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + q size: (bsz, n_q_head, seq_len, head_dim) + k size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim // 2) + sin size: (1, seq_len, head_dim // 2) + """ + q, k, cos, sin = rope_paper_forward(q, k, cos, sin) + ctx.save_for_backward(cos, sin) + return q, k + + def backward(ctx, dq, dk): + """ + dq size: (bsz, n_q_head, seq_len, head_dim) + dk size: (bsz, n_kv_head, seq_len, head_dim) + cos size: (1, seq_len, head_dim // 2) + sin size: (1, seq_len, head_dim // 2) + """ + + cos, sin = ctx.saved_tensors + dq, dk = rope_paper_backward(dq, dk, cos, sin) + return dq, dk, None, None, None, None diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 4f67fe8cf..84e008cae 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -25,6 +25,9 @@ from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer # noqa: F401 from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 +from liger_kernel.transformers.rope_paper import ( # noqa: F401 + liger_rotary_paper_pos_emb, +) from liger_kernel.transformers.swiglu import ( # noqa: F401 LigerBlockSparseTop2MLP, LigerPhi3SwiGLUMLP, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 45ad6159a..6151cccbf 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -13,6 +13,7 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction +from 
liger_kernel.ops.rope_paper import LigerRopePaperFunction from liger_kernel.ops.swiglu import LigerSiLUMulFunction @@ -169,5 +170,9 @@ def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) +def liger_rope_paper(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + return LigerRopePaperFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) + + def liger_swiglu(a, b): return LigerSiLUMulFunction.apply(a, b) diff --git a/src/liger_kernel/transformers/rope_paper.py b/src/liger_kernel/transformers/rope_paper.py new file mode 100644 index 000000000..2362e004d --- /dev/null +++ b/src/liger_kernel/transformers/rope_paper.py @@ -0,0 +1,20 @@ +from liger_kernel.ops.rope_paper import LigerRopePaperFunction + + +def liger_rotary_paper_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """ + Applies Rotary Positional Embedding (RoPE) operation to query and key states. + + Args: + q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). + k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). + position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. + unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation. + """ + + return LigerRopePaperFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) From 16a5ef153cf9279d8ab8ceb336783cb7de29344e Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Tue, 10 Dec 2024 21:48:34 +0800 Subject: [PATCH 2/8] Add Unit Tests --- test/transformers/test_rope_paper.py | 163 +++++++++++++++++++++++++ test/transformers/test_transformers.py | 1 + 2 files changed, 164 insertions(+) create mode 100644 test/transformers/test_rope_paper.py diff --git a/test/transformers/test_rope_paper.py b/test/transformers/test_rope_paper.py new file mode 100644 index 000000000..d2ae6f061 --- /dev/null +++ b/test/transformers/test_rope_paper.py @@ -0,0 +1,163 @@ +from test.utils import supports_bfloat16 + +import pytest +import torch +from transformers.models.roformer.modeling_roformer import ( + RoFormerSelfAttention, + RoFormerSinusoidalPositionalEmbedding, +) + +from liger_kernel.ops.rope_paper import LigerRopePaperFunction +from liger_kernel.transformers.functional import liger_rope_paper +from liger_kernel.transformers.rope_paper import liger_rotary_paper_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() + +SLEEP_SECONDS = 0.1 + +apply_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings + + +@pytest.mark.parametrize( + "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", + [ + (1, 128, 32, 32, 64), + (2, 128, 32, 32, 64), + # different q/k heads + (1, 128, 32, 8, 64), + (2, 128, 32, 8, 64), + # weird shapes + # HuggingFace llama/mistral source code doesn't support odd head dimension + # so we don't test it here + (3, 423, 73, 213, 92), + (3, 423, 73, 155, 92), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-5, 1e-5), + pytest.param( + torch.bfloat16, + 1e-1, + 1e-5, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + ], +) +def test_correctness( + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, 
dtype, atol, rtol +): + + _tensor_q = ( + torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) + .transpose(1, 2) + .to(dtype) + ) + + _tensor_k = ( + torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) + .transpose(1, 2) + .to(dtype) + ) + + q1 = _tensor_q.clone().requires_grad_(True) + k1 = _tensor_k.clone().requires_grad_(True) + + q2 = _tensor_q.clone().requires_grad_(True) + k2 = _tensor_k.clone().requires_grad_(True) + + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + sinusoidal_pos = rotary_emb((bsz, seq_len))[None, :, :].to(dtype) + + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + # validate forward pass + hf_q, hf_k = apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q1, k1) + tt_q, tt_k = liger_rotary_paper_pos_emb(q2, k2, cos, sin) + assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) + assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol) + + # validate backward pass + dq, dk = ( + torch.randn_like(hf_q, device=device), + torch.randn_like(hf_k, device=device).to(dtype), + ) + + q1_grad, k1_grad = torch.autograd.grad( + (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True + ) + q2_grad, k2_grad = torch.autograd.grad( + (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True + ) + + assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) + assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize( + "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", + [ + (1, 2, 2, 2, 8), + (1, 2, 1, 2, 8), + # weird shapes + (9, 7, 41, 41, 41), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-5, 1e-5), + (torch.bfloat16, 1e-1, 1e-5), + ], +) +def test_functional_correctness( + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol +): + _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) + _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) + + q1 = _q.clone().requires_grad_(True) + q2 = _q.clone().requires_grad_(True) + + k1 = _k.clone().requires_grad_(True) + k2 = _k.clone().requires_grad_(True) + + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + + sinusoidal_pos = rotary_emb((bsz, seq_len))[None, None, :, :].to(dtype) + + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + functional_q, functional_k = liger_rope_paper(q1, k1, cos, sin) + class_q, class_k = LigerRopePaperFunction.apply(q2, k2, cos, sin) + + assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol) + assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol) + + dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k) + + dq1, dk1 = dq.clone(), dk.clone() + dq2, dk2 = dq.clone(), dk.clone() + + q1_grad, k1_grad = torch.autograd.grad( + (functional_q, functional_k), + (q1, k1), + (dq1, dk1), + allow_unused=True, + ) + + q2_grad, k2_grad = torch.autograd.grad( + (class_q, class_k), + (q2, k2), + (dq2, dk2), + allow_unused=True, + ) + + assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) + assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) diff --git a/test/transformers/test_transformers.py b/test/transformers/test_transformers.py index 9601229ec..3e6d79123 100644 --- a/test/transformers/test_transformers.py +++ b/test/transformers/test_transformers.py @@ -12,6 +12,7 @@ def test_import_from_root(): LigerPhi3SwiGLUMLP, LigerRMSNorm, LigerSwiGLUMLP, + 
liger_rotary_paper_pos_emb, liger_rotary_pos_emb, ) except Exception: From 0c4ee6d0b754dbcd051a6c979661b4eeb571bfe7 Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Tue, 10 Dec 2024 21:50:22 +0800 Subject: [PATCH 3/8] Add benchmark --- benchmark/data/all_benchmark_data.csv | 64 ++++++ benchmark/scripts/benchmark_rope_paper.py | 237 ++++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 benchmark/scripts/benchmark_rope_paper.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 4e966cab2..27650665b 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -715,3 +715,67 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 +rope_paper,liger,forward,speed,ms,H,hidden size,512,0.027648000046610832,0.027648000046610832,0.028672000393271446,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 +rope_paper,liger,forward,speed,ms,H,hidden size,2048,0.1515520066022873,0.15052799880504608,0.15360000729560852,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 +rope_paper,liger,forward,speed,ms,H,hidden size,8192,0.5099520087242126,0.5079039931297302,0.5120000243186951,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 +rope_paper,huggingface,forward,speed,ms,H,hidden size,512,0.12800000607967377,0.12492799758911133,0.13209599256515503,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 +rope_paper,huggingface,forward,speed,ms,H,hidden size,2048,0.17203199863433838,0.17100800573825836,0.17203199863433838,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 +rope_paper,huggingface,forward,speed,ms,H,hidden size,8192,0.5396479964256287,0.5386239886283875,0.5416960120201111,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 +rope_paper,liger,backward,speed,ms,H,hidden size,512,0.021503999829292297,0.01945599913597107,0.02457600086927414,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:18,0.3.1 +rope_paper,liger,backward,speed,ms,H,hidden size,2048,0.13926400244235992,0.1382399946451187,0.14028799533843994,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 
80GB PCIe,2024-12-10 21:15:18,0.3.1 +rope_paper,liger,backward,speed,ms,H,hidden size,8192,0.49561598896980286,0.4935680031776428,0.4976640045642853,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:18,0.3.1 +rope_paper,huggingface,backward,speed,ms,H,hidden size,512,0.22732800245285034,0.22466561198234558,0.2314240038394928,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 +rope_paper,huggingface,backward,speed,ms,H,hidden size,2048,0.20787200331687927,0.20684799551963806,0.20787200331687927,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 +rope_paper,huggingface,backward,speed,ms,H,hidden size,8192,0.7290880084037781,0.7290880084037781,0.7301120162010193,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 +rope_paper,liger,full,speed,ms,H,hidden size,512,0.14233599603176117,0.10444799810647964,0.14622725546360016,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 +rope_paper,liger,full,speed,ms,H,hidden size,2048,0.28672000765800476,0.28467199206352234,0.2887679934501648,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 +rope_paper,liger,full,speed,ms,H,hidden size,8192,1.001471996307373,0.9983999729156494,1.0045440196990967,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 +rope_paper,huggingface,full,speed,ms,H,hidden size,512,0.44441598653793335,0.4413439929485321,0.45977601408958435,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 +rope_paper,huggingface,full,speed,ms,H,hidden size,2048,0.4249599874019623,0.42393600940704346,0.4280320107936859,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 +rope_paper,huggingface,full,speed,ms,H,hidden size,8192,1.2636159658432007,1.2625919580459595,1.265663981437683,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 +rope_paper,liger,full,memory,MB,H,hidden size,512,5.25,5.25,5.25,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 +rope_paper,liger,full,memory,MB,H,hidden size,2048,21.0,21.0,21.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 +rope_paper,liger,full,memory,MB,H,hidden size,8192,84.0,84.0,84.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 +rope_paper,huggingface,full,memory,MB,H,hidden size,512,14.3125,14.3125,14.3125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 +rope_paper,huggingface,full,memory,MB,H,hidden size,2048,57.25,57.25,57.25,"{""dtype"": 
""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 +rope_paper,huggingface,full,memory,MB,H,hidden size,8192,229.0,229.0,229.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 +rope_paper,liger,forward,speed,ms,T,sequence length,1024,0.2836480140686035,0.2815999984741211,0.28569599986076355,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 +rope_paper,liger,forward,speed,ms,T,sequence length,2048,0.5089280009269714,0.506879985332489,0.5120000243186951,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 +rope_paper,liger,forward,speed,ms,T,sequence length,4096,0.9666560292243958,0.9646080136299133,0.9697279930114746,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 +rope_paper,liger,forward,speed,ms,T,sequence length,8192,1.8821120262145996,1.8800640106201172,1.8851840496063232,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 +rope_paper,liger,forward,speed,ms,T,sequence length,16384,3.7099521160125732,3.7058560848236084,3.7130239009857178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 +rope_paper,huggingface,forward,speed,ms,T,sequence length,1024,0.289792001247406,0.2887679934501648,0.289792001247406,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 +rope_paper,huggingface,forward,speed,ms,T,sequence length,2048,0.5396479964256287,0.5386239886283875,0.5416960120201111,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 +rope_paper,huggingface,forward,speed,ms,T,sequence length,4096,1.0240000486373901,1.0219520330429077,1.026047945022583,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 +rope_paper,huggingface,forward,speed,ms,T,sequence length,8192,1.9967999458312988,1.994752049446106,1.9988479614257812,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 +rope_paper,huggingface,forward,speed,ms,T,sequence length,16384,3.9383039474487305,3.935231924057007,3.940351963043213,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 +rope_paper,liger,backward,speed,ms,T,sequence length,1024,0.2682879865169525,0.2662400007247925,0.27033600211143494,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 +rope_paper,liger,backward,speed,ms,T,sequence length,2048,0.49663999676704407,0.49459201097488403,0.4986880123615265,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 
+rope_paper,liger,backward,speed,ms,T,sequence length,4096,0.9523199796676636,0.9502720236778259,0.9553920030593872,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 +rope_paper,liger,backward,speed,ms,T,sequence length,8192,1.8626559972763062,1.8595839738845825,1.8657280206680298,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 +rope_paper,liger,backward,speed,ms,T,sequence length,16384,3.680255889892578,3.676774501800537,3.684351921081543,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 +rope_paper,huggingface,backward,speed,ms,T,sequence length,1024,0.37068799138069153,0.3696640133857727,0.37171199917793274,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 +rope_paper,huggingface,backward,speed,ms,T,sequence length,2048,0.7311360239982605,0.7301120162010193,0.7321599721908569,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 +rope_paper,huggingface,backward,speed,ms,T,sequence length,4096,1.3957120180130005,1.3946880102157593,1.3967360258102417,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 +rope_paper,huggingface,backward,speed,ms,T,sequence length,8192,2.751487970352173,2.7494399547576904,2.7535359859466553,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 +rope_paper,huggingface,backward,speed,ms,T,sequence length,16384,5.410816192626953,5.40774393081665,5.413887977600098,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 +rope_paper,liger,full,speed,ms,T,sequence length,1024,0.5478399991989136,0.5457919836044312,0.5509120225906372,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 +rope_paper,liger,full,speed,ms,T,sequence length,2048,1.0024960041046143,1.0004479885101318,1.0061824321746826,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 +rope_paper,liger,full,speed,ms,T,sequence length,4096,1.9169280529022217,1.913856029510498,1.921023964881897,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 +rope_paper,liger,full,speed,ms,T,sequence length,8192,3.742719888687134,3.7396481037139893,3.74783992767334,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 +rope_paper,liger,full,speed,ms,T,sequence length,16384,7.387135982513428,7.383449554443359,7.389798164367676,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 +rope_paper,huggingface,full,speed,ms,T,sequence length,1024,0.6563839912414551,0.6553599834442139,0.6574079990386963,"{""dtype"": 
""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 +rope_paper,huggingface,full,speed,ms,T,sequence length,2048,1.264639973640442,1.2636159658432007,1.265663981437683,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 +rope_paper,huggingface,full,speed,ms,T,sequence length,4096,2.411520004272461,2.4094719886779785,2.412544012069702,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 +rope_paper,huggingface,full,speed,ms,T,sequence length,8192,4.7472639083862305,4.745215892791748,4.750336170196533,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 +rope_paper,huggingface,full,speed,ms,T,sequence length,16384,9.334783554077148,9.330893516540527,9.336832046508789,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 +rope_paper,liger,full,memory,MB,T,sequence length,1024,42.0,42.0,42.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 +rope_paper,liger,full,memory,MB,T,sequence length,2048,84.0,84.0,84.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 +rope_paper,liger,full,memory,MB,T,sequence length,4096,168.0,168.0,168.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 +rope_paper,liger,full,memory,MB,T,sequence length,8192,336.0,336.0,336.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 +rope_paper,liger,full,memory,MB,T,sequence length,16384,672.0,672.0,672.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 +rope_paper,huggingface,full,memory,MB,T,sequence length,1024,114.5,114.5,114.5,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 +rope_paper,huggingface,full,memory,MB,T,sequence length,2048,229.0,229.0,229.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 +rope_paper,huggingface,full,memory,MB,T,sequence length,4096,458.0,458.0,458.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 +rope_paper,huggingface,full,memory,MB,T,sequence length,8192,916.0,916.0,916.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 +rope_paper,huggingface,full,memory,MB,T,sequence length,16384,1832.0,1832.0,1832.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 diff --git a/benchmark/scripts/benchmark_rope_paper.py b/benchmark/scripts/benchmark_rope_paper.py new file mode 100644 index 
000000000..b5e615f49 --- /dev/null +++ b/benchmark/scripts/benchmark_rope_paper.py @@ -0,0 +1,237 @@ +import torch +import triton +from transformers.models.roformer.modeling_roformer import ( + RoFormerSelfAttention, + RoFormerSinusoidalPositionalEmbedding, +) +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.rope_paper import liger_rotary_paper_pos_emb +from liger_kernel.utils import infer_device + +device = infer_device() +apply_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings + + +def bench_speed_rope_paper(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + mode = input.kernel_operation_mode + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = ( + extra_benchmark_config["hidden_size"] + if "hidden_size" in extra_benchmark_config + else input.x + ) + seq_len = ( + extra_benchmark_config["seq_len"] + if "seq_len" in extra_benchmark_config + else input.x + ) + + head_dim = hidden_size // num_q_heads + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device + ) + + sinusoidal_pos = rotary_emb((1, seq_len))[None, :, :].to(dtype) + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + + def fwd(): + if provider == "liger": + return liger_rotary_paper_pos_emb(q, k, cos, sin) + elif provider == "huggingface": + return apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q, k) + else: + raise ValueError(f"Invalid provider: {provider} for RoPE paper embedding") + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "backward": + q_out, k_out = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: torch.autograd.grad( + (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True + ), + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + q_out, k_out = fwd() + torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + grad_to_none=[q, k], + rep=400, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_rope_paper(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + provider = input.kernel_provider + + extra_benchmark_config = input.extra_benchmark_config + num_q_heads = extra_benchmark_config["num_q_heads"] + num_kv_heads = extra_benchmark_config["num_kv_heads"] + dtype = extra_benchmark_config["dtype"] + + # x can be either hidden_size or seq_len + hidden_size = ( + extra_benchmark_config["hidden_size"] + if "hidden_size" in extra_benchmark_config + else input.x + ) + seq_len = ( + extra_benchmark_config["seq_len"] + if "seq_len" in 
extra_benchmark_config + else input.x + ) + + head_dim = hidden_size // num_q_heads + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + q = torch.randn( + (1, seq_len, num_q_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + k = torch.randn( + (1, seq_len, num_kv_heads, head_dim), + device=device, + requires_grad=True, + dtype=dtype, + ).transpose(1, 2) + dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( + k, device=device + ) + + sinusoidal_pos = rotary_emb((1, seq_len))[None, :, :].to(dtype) + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + + def full(): + if provider == "liger": + q_out, k_out = liger_rotary_paper_pos_emb(q, k, cos, sin) + else: + q_out, k_out = apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q, k) + torch.autograd.grad( + (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True + ) + + mem_50, mem_20, mem_80 = _test_memory( + full, + quantiles=QUANTILES, + ) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs_varying_hidden_size = { + "kernel_name": "rope_paper", + "x_name": "H", + "x_label": "hidden size", + "x_values": [32 * (2**i) for i in range(4, 10, 2)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "seq_len": 2048, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_rope_paper, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_hidden_size, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope_paper, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_hidden_size, + ) + + common_configs_varying_seq_len = { + "kernel_name": "rope_paper", + "x_name": "T", + "x_label": "sequence length", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [ + { + "dtype": torch.bfloat16, + "hidden_size": 8192, + "num_q_heads": 32, + "num_kv_heads": 8, + } + ], + "overwrite": args.overwrite, + } + run_benchmarks( + bench_test_fn=bench_speed_rope_paper, + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs_varying_seq_len, + ) + run_benchmarks( + bench_test_fn=bench_memory_rope_paper, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs_varying_seq_len, + ) From 571e7daa591dd2c20c01181ca1703cf4966eeb07 Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Wed, 11 Dec 2024 17:17:34 +0800 Subject: [PATCH 4/8] Add paper-form option to Liger Kernel's RoPE implementation --- src/liger_kernel/ops/rope.py | 275 ++++++++++++++++---- src/liger_kernel/transformers/functional.py | 10 +- src/liger_kernel/transformers/rope.py | 9 +- 3 files changed, 238 insertions(+), 56 deletions(-) diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py index 0cd88efeb..872802fed 100644 --- a/src/liger_kernel/ops/rope.py +++ b/src/liger_kernel/ops/rope.py @@ -117,7 +117,126 @@ def _triton_rope( tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) -def rope_forward(q, k, cos, sin): +@triton.jit +def _triton_rope_paper( + q_ptr, + q_row_stride, + k_ptr, + 
k_row_stride, + cos, + cos_row_stride, + sin, + sin_row_stride, + sl, + bs: tl.constexpr, + n_qh: tl.constexpr, + n_kh: tl.constexpr, + hd: tl.constexpr, + pad_n_qh: tl.constexpr, + pad_n_kh: tl.constexpr, + pad_hd: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BACKWARD_PASS: tl.constexpr = False, +): + # q size: (bsz, seq_len, num_q_heads, head_dim) + # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) + # k size: (bsz, seq_len, num_kv_heads, head_dim) + # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) + + # cos size: (1, seq_len, head_dim // 2) + # stride: (seq_len * head_dim, head_dim, 1) + pid = tl.program_id(0) + + # locate start address + q_ptr = q_ptr + pid * q_row_stride + k_ptr = k_ptr + pid * k_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this program instance + # #################################################################### + + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index + # and pid % sl to get the sequence index. + # 2. We only need the left half of cos and sin matrix because the right half is just + # a clone of the left half. + cos_row_idx = pid % (sl) + cos = cos + cos_row_idx * cos_row_stride + sin = sin + cos_row_idx * sin_row_stride + cos_offsets = tl.arange(0, pad_hd // 2) + cos_mask = cos_offsets < hd // 2 + cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) + sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) + + # #################################################################### + # Load the even-indexed and odd-indexed elements of q and k for the current + # program instance (i.e. 
for the current token) separately + # #################################################################### + # even-indexed elements of the head + even_q_offsets = ( + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 + ) + even_k_offsets = ( + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 + ) + even_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 < hd + ) + even_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 < hd + ) + q_tile_even = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_even = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to( + sin_row.dtype + ) + + # odd-indexed elements of the head + odd_q_offsets = even_q_offsets + 1 + odd_k_offsets = even_k_offsets + 1 + odd_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd + ) + odd_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( + tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd + ) + q_tile_odd = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to( + sin_row.dtype + ) + k_tile_odd = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to( + sin_row.dtype + ) + + if not BACKWARD_PASS: + # y_even = x_even * cos - x_odd * sin + # y_odd = x_odd * cos + x_even * sin + new_q_tile_even = q_tile_even * cos_row - q_tile_odd * sin_row + tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) + new_q_tile_odd = q_tile_odd * cos_row + q_tile_even * sin_row + tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) + + new_k_tile_even = k_tile_even * cos_row - k_tile_odd * sin_row + tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) + new_k_tile_odd = k_tile_odd * cos_row + k_tile_even * sin_row + tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) + else: + # dy_even = dx_even * cos + dx_odd * sin + # dy_odd = dx_odd * cos - dx_even * sin + new_q_tile_even = q_tile_even * cos_row + q_tile_odd * sin_row + tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) + new_q_tile_odd = q_tile_odd * cos_row - q_tile_even * sin_row + tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) + + new_k_tile_even = k_tile_even * cos_row + k_tile_odd * sin_row + tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) + new_k_tile_odd = k_tile_odd * cos_row - k_tile_even * sin_row + tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) + + +def rope_forward(q, k, cos, sin, paper_form): # transpose it back to the physical shape because Triton looks at the physical storage # note: q and k are incontiguous before the transformation and will become contiguous after transpose @@ -139,30 +258,52 @@ def rope_forward(q, k, cos, sin): cos = cos.contiguous() sin = sin.contiguous() - _triton_rope[(n_row,)]( - q, - q.stride(1), - k, - k.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=False, - ) + if not paper_form: + _triton_rope[(n_row,)]( + q, + q.stride(1), + k, + k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) + else: + _triton_rope_paper[(n_row,)]( + q, + q.stride(1), + k, + 
k.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=False, + ) return q.transpose(1, 2), k.transpose(1, 2), cos, sin -def rope_backward(dq, dk, cos, sin): +def rope_backward(dq, dk, cos, sin, paper_form): dq = dq.transpose(1, 2) dk = dk.transpose(1, 2) @@ -180,51 +321,85 @@ def rope_backward(dq, dk, cos, sin): dk = dk.contiguous() # backward is similar to forward except swapping few ops - _triton_rope[(n_row,)]( - dq, - dq.stride(1), - dk, - dk.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=True, - ) + if not paper_form: + _triton_rope[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) + else: + _triton_rope_paper[(n_row,)]( + dq, + dq.stride(1), + dk, + dk.stride(1), + cos, + cos.stride(-2), + sin, + sin.stride(-2), + seq_len, + batch_size, + n_q_head, + n_kv_head, + head_dim, + pad_n_q_head, + pad_n_kv_head, + pad_hd, + BLOCK_SIZE=BLOCK_SIZE, + BACKWARD_PASS=True, + ) return dq.transpose(1, 2), dk.transpose(1, 2) class LigerRopeFunction(torch.autograd.Function): """ - Triton implementation of the Rotary Positional Embedding (RoPE) operation. Please note that - this implements the HuggingFace Llama & Mistral version, whose rotation matrix is slightly different - than the original RoPE paper. + Triton implementation of the Rotary Positional Embedding (RoPE) operation. + This implements both HuggingFace Llama & Mistral version and the original RoPE paper version. 
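+    The two forms differ only in how the head dimensions are paired before the 2x2
+    rotation. The paper (interleaved, RoFormer-style) form computed by _triton_rope_paper
+    rotates adjacent pairs (x_{2i}, x_{2i+1}) of each head by the angle m*theta_i:
+        y_{2i}   = x_{2i}   * cos(m*theta_i) - x_{2i+1} * sin(m*theta_i)
+        y_{2i+1} = x_{2i+1} * cos(m*theta_i) + x_{2i}   * sin(m*theta_i)
+    while the HuggingFace Llama & Mistral form applies the same rotation to the
+    pairs (x_i, x_{i + head_dim/2}).
+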
- Please find the corresponding HuggingFace implementation here: + Please find the corresponding HuggingFace Llama & Mistral implementation here: https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llama/modeling_llama.py#L184 + Please find the corresponding HuggingFace paper-form implementation here: + https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/roformer/modeling_roformer.py#L309 + For more details about the rotation matrix used here, please refer to: https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/2 """ @staticmethod - def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + def forward( + ctx, + q, + k, + cos, + sin, + position_ids=None, + unsqueeze_dim=1, + paper_form: bool = False, + ): """ q size: (bsz, n_q_head, seq_len, head_dim) k size: (bsz, n_kv_head, seq_len, head_dim) cos size: (1, seq_len, head_dim) sin size: (1, seq_len, head_dim) """ - q, k, cos, sin = rope_forward(q, k, cos, sin) + q, k, cos, sin = rope_forward(q, k, cos, sin, paper_form) + ctx.paper_form = paper_form ctx.save_for_backward(cos, sin) return q, k @@ -237,5 +412,5 @@ def backward(ctx, dq, dk): """ cos, sin = ctx.saved_tensors - dq, dk = rope_backward(dq, dk, cos, sin) - return dq, dk, None, None, None, None + dq, dk = rope_backward(dq, dk, cos, sin, ctx.paper_form) + return dq, dk, None, None, None, None, None diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 6151cccbf..e776163f9 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -166,10 +166,12 @@ def liger_rms_norm( return LigerRMSNormFunction.apply(X, W, eps, offset, casting_mode, in_place) -def liger_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) - - +def liger_rope( + q, k, cos, sin, position_ids=None, unsqueeze_dim=1, paper_form: bool = False +): + return LigerRopeFunction.apply( + q, k, cos, sin, position_ids, unsqueeze_dim, paper_form + ) def liger_rope_paper(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return LigerRopePaperFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) diff --git a/src/liger_kernel/transformers/rope.py b/src/liger_kernel/transformers/rope.py index a40b29af3..8b7ea8e61 100644 --- a/src/liger_kernel/transformers/rope.py +++ b/src/liger_kernel/transformers/rope.py @@ -1,7 +1,9 @@ from liger_kernel.ops.rope import LigerRopeFunction -def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def liger_rotary_pos_emb( + q, k, cos, sin, position_ids=None, unsqueeze_dim=1, paper_form: bool = False +): """ Applies Rotary Positional Embedding (RoPE) operation to query and key states. @@ -12,9 +14,12 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. + paper_form (bool, optional): Whether to use the paper-form RoPE rotary matrix. Defaults to false. Returns: Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation. 
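+
+    Example (a minimal sketch mirroring the unit tests; the device string and the
+    RoFormer helper used to build `cos`/`sin` for the paper form are illustrative
+    choices, not requirements of this function):
+
+        >>> import torch
+        >>> from transformers.models.roformer.modeling_roformer import (
+        ...     RoFormerSinusoidalPositionalEmbedding,
+        ... )
+        >>> bsz, seq_len, n_q_head, n_kv_head, head_dim = 1, 128, 32, 8, 64
+        >>> q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
+        >>> k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
+        >>> rotary = RoFormerSinusoidalPositionalEmbedding(
+        ...     num_positions=seq_len, embedding_dim=head_dim
+        ... ).to("cuda")
+        >>> # RoFormer concatenates [sin, cos] along the last dim; paper_form expects
+        >>> # cos/sin of shape (1, seq_len, head_dim // 2)
+        >>> sin, cos = rotary((bsz, seq_len))[None, :, :].chunk(2, dim=-1)
+        >>> q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin, paper_form=True)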
""" - return LigerRopeFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) + return LigerRopeFunction.apply( + q, k, cos, sin, position_ids, unsqueeze_dim, paper_form + ) From 51ff8ddbd188d889c0e0275f2aab7043ddb4366e Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Wed, 11 Dec 2024 17:21:26 +0800 Subject: [PATCH 5/8] Add Unit Tests for the paper-form option in RoPE --- test/transformers/test_rope.py | 77 +++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 15 deletions(-) diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 74080b57f..7f341a6d0 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -6,12 +6,19 @@ LlamaRotaryEmbedding, apply_rotary_pos_emb, ) +from transformers.models.roformer.modeling_roformer import ( + RoFormerSelfAttention, + RoFormerSinusoidalPositionalEmbedding, +) from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.transformers.functional import liger_rope from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.utils import infer_device +apply_paper_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings + + device = infer_device() SLEEP_SECONDS = 0.1 @@ -46,10 +53,16 @@ ), ], ) +@pytest.mark.parametrize( + "paper_form", + [ + False, + True, + ], +) def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol, paper_form ): - rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) _tensor_q = ( torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) @@ -69,12 +82,25 @@ def test_correctness( q2 = _tensor_q.clone().requires_grad_(True) k2 = _tensor_k.clone().requires_grad_(True) - pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - cos, sin = rotary_emb(k1, pos_ids) - # validate forward pass - hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) - tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin) + if not paper_form: + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) + + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + cos, sin = rotary_emb(k1, pos_ids) + + hf_q, hf_k = apply_rotary_pos_emb(q1, k1, cos, sin, pos_ids) + else: + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + + sinusoidal_pos = rotary_emb((bsz, seq_len)).unsqueeze(0).to(dtype) + + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + hf_q, hf_k = apply_paper_rotary_pos_emb(sinusoidal_pos.unsqueeze(0), q1, k1) + + tt_q, tt_k = liger_rotary_pos_emb(q2, k2, cos, sin, paper_form=paper_form) assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol) @@ -111,8 +137,15 @@ def test_correctness( (torch.bfloat16, 1e-1, 1e-5), ], ) +@pytest.mark.parametrize( + "paper_form", + [ + False, + True, + ], +) def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol, paper_form ): _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) @@ -123,13 +156,27 @@ def test_functional_correctness( k1 = _k.clone().requires_grad_(True) k2 = _k.clone().requires_grad_(True) - rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) - - pos_ids = 
torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) - cos, sin = rotary_emb(k1, pos_ids) - - functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin) - class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin) + if not paper_form: + rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) + + pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + cos, sin = rotary_emb(k1, pos_ids) + else: + rotary_emb = RoFormerSinusoidalPositionalEmbedding( + num_positions=seq_len, embedding_dim=head_dim + ).to(device) + sinusoidal_pos = rotary_emb((bsz, seq_len)).unsqueeze(0).to(dtype) + + sin, cos = sinusoidal_pos.chunk(2, dim=-1) + + functional_q, functional_k = liger_rope( + q=q1, + k=k1, + cos=cos, + sin=sin, + paper_form=paper_form, + ) + class_q, class_k = LigerRopeFunction.apply(q2, k2, cos, sin, None, 1, paper_form) assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol) assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol) From 805dae96126eb32097f38c17451a0691639575da Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Wed, 11 Dec 2024 17:32:19 +0800 Subject: [PATCH 6/8] Cleanup --- benchmark/data/all_benchmark_data.csv | 66 +----- benchmark/scripts/benchmark_rope_paper.py | 237 ------------------- src/liger_kernel/ops/rope_paper.py | 241 -------------------- src/liger_kernel/transformers/__init__.py | 3 - src/liger_kernel/transformers/functional.py | 1 - src/liger_kernel/transformers/rope_paper.py | 20 -- test/transformers/test_rope_paper.py | 163 ------------- 7 files changed, 1 insertion(+), 730 deletions(-) delete mode 100644 benchmark/scripts/benchmark_rope_paper.py delete mode 100644 src/liger_kernel/ops/rope_paper.py delete mode 100644 src/liger_kernel/transformers/rope_paper.py delete mode 100644 test/transformers/test_rope_paper.py diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 27650665b..398e03e01 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -714,68 +714,4 @@ fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 -fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 -rope_paper,liger,forward,speed,ms,H,hidden size,512,0.027648000046610832,0.027648000046610832,0.028672000393271446,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 -rope_paper,liger,forward,speed,ms,H,hidden size,2048,0.1515520066022873,0.15052799880504608,0.15360000729560852,"{""dtype"": 
""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 -rope_paper,liger,forward,speed,ms,H,hidden size,8192,0.5099520087242126,0.5079039931297302,0.5120000243186951,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:14,0.3.1 -rope_paper,huggingface,forward,speed,ms,H,hidden size,512,0.12800000607967377,0.12492799758911133,0.13209599256515503,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 -rope_paper,huggingface,forward,speed,ms,H,hidden size,2048,0.17203199863433838,0.17100800573825836,0.17203199863433838,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 -rope_paper,huggingface,forward,speed,ms,H,hidden size,8192,0.5396479964256287,0.5386239886283875,0.5416960120201111,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:16,0.3.1 -rope_paper,liger,backward,speed,ms,H,hidden size,512,0.021503999829292297,0.01945599913597107,0.02457600086927414,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:18,0.3.1 -rope_paper,liger,backward,speed,ms,H,hidden size,2048,0.13926400244235992,0.1382399946451187,0.14028799533843994,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:18,0.3.1 -rope_paper,liger,backward,speed,ms,H,hidden size,8192,0.49561598896980286,0.4935680031776428,0.4976640045642853,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:18,0.3.1 -rope_paper,huggingface,backward,speed,ms,H,hidden size,512,0.22732800245285034,0.22466561198234558,0.2314240038394928,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 -rope_paper,huggingface,backward,speed,ms,H,hidden size,2048,0.20787200331687927,0.20684799551963806,0.20787200331687927,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 -rope_paper,huggingface,backward,speed,ms,H,hidden size,8192,0.7290880084037781,0.7290880084037781,0.7301120162010193,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:20,0.3.1 -rope_paper,liger,full,speed,ms,H,hidden size,512,0.14233599603176117,0.10444799810647964,0.14622725546360016,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 -rope_paper,liger,full,speed,ms,H,hidden size,2048,0.28672000765800476,0.28467199206352234,0.2887679934501648,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 -rope_paper,liger,full,speed,ms,H,hidden size,8192,1.001471996307373,0.9983999729156494,1.0045440196990967,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:22,0.3.1 -rope_paper,huggingface,full,speed,ms,H,hidden 
size,512,0.44441598653793335,0.4413439929485321,0.45977601408958435,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 -rope_paper,huggingface,full,speed,ms,H,hidden size,2048,0.4249599874019623,0.42393600940704346,0.4280320107936859,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 -rope_paper,huggingface,full,speed,ms,H,hidden size,8192,1.2636159658432007,1.2625919580459595,1.265663981437683,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:24,0.3.1 -rope_paper,liger,full,memory,MB,H,hidden size,512,5.25,5.25,5.25,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 -rope_paper,liger,full,memory,MB,H,hidden size,2048,21.0,21.0,21.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 -rope_paper,liger,full,memory,MB,H,hidden size,8192,84.0,84.0,84.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:25,0.3.1 -rope_paper,huggingface,full,memory,MB,H,hidden size,512,14.3125,14.3125,14.3125,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 -rope_paper,huggingface,full,memory,MB,H,hidden size,2048,57.25,57.25,57.25,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 -rope_paper,huggingface,full,memory,MB,H,hidden size,8192,229.0,229.0,229.0,"{""dtype"": ""torch.bfloat16"", ""seq_len"": 2048, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:26,0.3.1 -rope_paper,liger,forward,speed,ms,T,sequence length,1024,0.2836480140686035,0.2815999984741211,0.28569599986076355,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 -rope_paper,liger,forward,speed,ms,T,sequence length,2048,0.5089280009269714,0.506879985332489,0.5120000243186951,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 -rope_paper,liger,forward,speed,ms,T,sequence length,4096,0.9666560292243958,0.9646080136299133,0.9697279930114746,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 -rope_paper,liger,forward,speed,ms,T,sequence length,8192,1.8821120262145996,1.8800640106201172,1.8851840496063232,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 -rope_paper,liger,forward,speed,ms,T,sequence length,16384,3.7099521160125732,3.7058560848236084,3.7130239009857178,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:37,0.3.1 -rope_paper,huggingface,forward,speed,ms,T,sequence length,1024,0.289792001247406,0.2887679934501648,0.289792001247406,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA 
A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 -rope_paper,huggingface,forward,speed,ms,T,sequence length,2048,0.5396479964256287,0.5386239886283875,0.5416960120201111,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 -rope_paper,huggingface,forward,speed,ms,T,sequence length,4096,1.0240000486373901,1.0219520330429077,1.026047945022583,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 -rope_paper,huggingface,forward,speed,ms,T,sequence length,8192,1.9967999458312988,1.994752049446106,1.9988479614257812,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 -rope_paper,huggingface,forward,speed,ms,T,sequence length,16384,3.9383039474487305,3.935231924057007,3.940351963043213,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:15:48,0.3.1 -rope_paper,liger,backward,speed,ms,T,sequence length,1024,0.2682879865169525,0.2662400007247925,0.27033600211143494,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 -rope_paper,liger,backward,speed,ms,T,sequence length,2048,0.49663999676704407,0.49459201097488403,0.4986880123615265,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 -rope_paper,liger,backward,speed,ms,T,sequence length,4096,0.9523199796676636,0.9502720236778259,0.9553920030593872,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 -rope_paper,liger,backward,speed,ms,T,sequence length,8192,1.8626559972763062,1.8595839738845825,1.8657280206680298,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 -rope_paper,liger,backward,speed,ms,T,sequence length,16384,3.680255889892578,3.676774501800537,3.684351921081543,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:00,0.3.1 -rope_paper,huggingface,backward,speed,ms,T,sequence length,1024,0.37068799138069153,0.3696640133857727,0.37171199917793274,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 -rope_paper,huggingface,backward,speed,ms,T,sequence length,2048,0.7311360239982605,0.7301120162010193,0.7321599721908569,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 -rope_paper,huggingface,backward,speed,ms,T,sequence length,4096,1.3957120180130005,1.3946880102157593,1.3967360258102417,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 -rope_paper,huggingface,backward,speed,ms,T,sequence length,8192,2.751487970352173,2.7494399547576904,2.7535359859466553,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 
-rope_paper,huggingface,backward,speed,ms,T,sequence length,16384,5.410816192626953,5.40774393081665,5.413887977600098,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:11,0.3.1 -rope_paper,liger,full,speed,ms,T,sequence length,1024,0.5478399991989136,0.5457919836044312,0.5509120225906372,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 -rope_paper,liger,full,speed,ms,T,sequence length,2048,1.0024960041046143,1.0004479885101318,1.0061824321746826,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 -rope_paper,liger,full,speed,ms,T,sequence length,4096,1.9169280529022217,1.913856029510498,1.921023964881897,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 -rope_paper,liger,full,speed,ms,T,sequence length,8192,3.742719888687134,3.7396481037139893,3.74783992767334,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 -rope_paper,liger,full,speed,ms,T,sequence length,16384,7.387135982513428,7.383449554443359,7.389798164367676,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:22,0.3.1 -rope_paper,huggingface,full,speed,ms,T,sequence length,1024,0.6563839912414551,0.6553599834442139,0.6574079990386963,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 -rope_paper,huggingface,full,speed,ms,T,sequence length,2048,1.264639973640442,1.2636159658432007,1.265663981437683,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 -rope_paper,huggingface,full,speed,ms,T,sequence length,4096,2.411520004272461,2.4094719886779785,2.412544012069702,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 -rope_paper,huggingface,full,speed,ms,T,sequence length,8192,4.7472639083862305,4.745215892791748,4.750336170196533,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 -rope_paper,huggingface,full,speed,ms,T,sequence length,16384,9.334783554077148,9.330893516540527,9.336832046508789,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:33,0.3.1 -rope_paper,liger,full,memory,MB,T,sequence length,1024,42.0,42.0,42.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 -rope_paper,liger,full,memory,MB,T,sequence length,2048,84.0,84.0,84.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 -rope_paper,liger,full,memory,MB,T,sequence length,4096,168.0,168.0,168.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 
-rope_paper,liger,full,memory,MB,T,sequence length,8192,336.0,336.0,336.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 -rope_paper,liger,full,memory,MB,T,sequence length,16384,672.0,672.0,672.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:42,0.3.1 -rope_paper,huggingface,full,memory,MB,T,sequence length,1024,114.5,114.5,114.5,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 -rope_paper,huggingface,full,memory,MB,T,sequence length,2048,229.0,229.0,229.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 -rope_paper,huggingface,full,memory,MB,T,sequence length,4096,458.0,458.0,458.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 -rope_paper,huggingface,full,memory,MB,T,sequence length,8192,916.0,916.0,916.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 -rope_paper,huggingface,full,memory,MB,T,sequence length,16384,1832.0,1832.0,1832.0,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA A100 80GB PCIe,2024-12-10 21:16:51,0.3.1 +fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1 \ No newline at end of file diff --git a/benchmark/scripts/benchmark_rope_paper.py b/benchmark/scripts/benchmark_rope_paper.py deleted file mode 100644 index b5e615f49..000000000 --- a/benchmark/scripts/benchmark_rope_paper.py +++ /dev/null @@ -1,237 +0,0 @@ -import torch -import triton -from transformers.models.roformer.modeling_roformer import ( - RoFormerSelfAttention, - RoFormerSinusoidalPositionalEmbedding, -) -from utils import ( - QUANTILES, - SingleBenchmarkRunInput, - SingleBenchmarkRunOutput, - _test_memory, - parse_benchmark_script_args, - run_benchmarks, -) - -from liger_kernel.transformers.rope_paper import liger_rotary_paper_pos_emb -from liger_kernel.utils import infer_device - -device = infer_device() -apply_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings - - -def bench_speed_rope_paper(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider - mode = input.kernel_operation_mode - - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] - - # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) - - head_dim = hidden_size // num_q_heads - rotary_emb = RoFormerSinusoidalPositionalEmbedding( - num_positions=seq_len, embedding_dim=head_dim - ).to(device) - q = torch.randn( - (1, seq_len, num_q_heads, head_dim), - device=device, - requires_grad=True, - 
dtype=dtype, - ).transpose(1, 2) - k = torch.randn( - (1, seq_len, num_kv_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device - ) - - sinusoidal_pos = rotary_emb((1, seq_len))[None, :, :].to(dtype) - sin, cos = sinusoidal_pos.chunk(2, dim=-1) - - def fwd(): - if provider == "liger": - return liger_rotary_paper_pos_emb(q, k, cos, sin) - elif provider == "huggingface": - return apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q, k) - else: - raise ValueError(f"Invalid provider: {provider} for RoPE paper embedding") - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - grad_to_none=[q, k], - rep=400, - quantiles=QUANTILES, - ) - elif mode == "backward": - q_out, k_out = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ), - grad_to_none=[q, k], - rep=400, - quantiles=QUANTILES, - ) - elif mode == "full": - - def full(): - q_out, k_out = fwd() - torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True) - - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - grad_to_none=[q, k], - rep=400, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_rope_paper(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - provider = input.kernel_provider - - extra_benchmark_config = input.extra_benchmark_config - num_q_heads = extra_benchmark_config["num_q_heads"] - num_kv_heads = extra_benchmark_config["num_kv_heads"] - dtype = extra_benchmark_config["dtype"] - - # x can be either hidden_size or seq_len - hidden_size = ( - extra_benchmark_config["hidden_size"] - if "hidden_size" in extra_benchmark_config - else input.x - ) - seq_len = ( - extra_benchmark_config["seq_len"] - if "seq_len" in extra_benchmark_config - else input.x - ) - - head_dim = hidden_size // num_q_heads - rotary_emb = RoFormerSinusoidalPositionalEmbedding( - num_positions=seq_len, embedding_dim=head_dim - ).to(device) - q = torch.randn( - (1, seq_len, num_q_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - k = torch.randn( - (1, seq_len, num_kv_heads, head_dim), - device=device, - requires_grad=True, - dtype=dtype, - ).transpose(1, 2) - dq, dk = torch.randn_like(q, device=device, dtype=dtype), torch.randn_like( - k, device=device - ) - - sinusoidal_pos = rotary_emb((1, seq_len))[None, :, :].to(dtype) - sin, cos = sinusoidal_pos.chunk(2, dim=-1) - - def full(): - if provider == "liger": - q_out, k_out = liger_rotary_paper_pos_emb(q, k, cos, sin) - else: - q_out, k_out = apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q, k) - torch.autograd.grad( - (q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True - ) - - mem_50, mem_20, mem_80 = _test_memory( - full, - quantiles=QUANTILES, - ) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) - - -if __name__ == "__main__": - args = parse_benchmark_script_args() - - common_configs_varying_hidden_size = { - "kernel_name": "rope_paper", - "x_name": "H", - "x_label": "hidden size", - "x_values": [32 * (2**i) for i in range(4, 10, 2)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "seq_len": 2048, - "num_q_heads": 32, - "num_kv_heads": 8, - } - ], - "overwrite": 
args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_rope_paper, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_hidden_size, - ) - run_benchmarks( - bench_test_fn=bench_memory_rope_paper, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_hidden_size, - ) - - common_configs_varying_seq_len = { - "kernel_name": "rope_paper", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, 15)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "dtype": torch.bfloat16, - "hidden_size": 8192, - "num_q_heads": 32, - "num_kv_heads": 8, - } - ], - "overwrite": args.overwrite, - } - run_benchmarks( - bench_test_fn=bench_speed_rope_paper, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs_varying_seq_len, - ) - run_benchmarks( - bench_test_fn=bench_memory_rope_paper, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs_varying_seq_len, - ) diff --git a/src/liger_kernel/ops/rope_paper.py b/src/liger_kernel/ops/rope_paper.py deleted file mode 100644 index aca347ed5..000000000 --- a/src/liger_kernel/ops/rope_paper.py +++ /dev/null @@ -1,241 +0,0 @@ -import torch -import triton -import triton.language as tl - - -@triton.jit -def _triton_rope_paper( - q_ptr, - q_row_stride, - k_ptr, - k_row_stride, - cos, - cos_row_stride, - sin, - sin_row_stride, - sl, - bs: tl.constexpr, - n_qh: tl.constexpr, - n_kh: tl.constexpr, - hd: tl.constexpr, - pad_n_qh: tl.constexpr, - pad_n_kh: tl.constexpr, - pad_hd: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - BACKWARD_PASS: tl.constexpr = False, -): - # q size: (bsz, seq_len, num_q_heads, head_dim) - # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) - # k size: (bsz, seq_len, num_kv_heads, head_dim) - # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) - - # cos size: (1, seq_len, head_dim // 2) - # stride: (seq_len * head_dim, head_dim, 1) - pid = tl.program_id(0) - - # locate start address - q_ptr = q_ptr + pid * q_row_stride - k_ptr = k_ptr + pid * k_row_stride - - # #################################################################### - # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position - # m of this program instance - # #################################################################### - - # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which - # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension - # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index - # and pid % sl to get the sequence index. - # 2. cos and sin matrices are already in the shape (1, seq_len, head_dim // 2), so we simply load the entire matrix. - cos_row_idx = pid % (sl) - cos = cos + cos_row_idx * cos_row_stride - sin = sin + cos_row_idx * sin_row_stride - cos_offsets = tl.arange(0, pad_hd // 2) - cos_mask = cos_offsets < hd // 2 - cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) - sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) - - # #################################################################### - # Load the even-indexed and odd-indexed elements of q and k for the current - # program instance (i.e. 
for the current token) separately - # #################################################################### - # even-indexed elements of the head - even_q_offsets = ( - tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 - ) - even_k_offsets = ( - tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] * 2 - ) - even_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( - tl.arange(0, pad_hd // 2)[None, :] * 2 < hd - ) - even_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( - tl.arange(0, pad_hd // 2)[None, :] * 2 < hd - ) - q_tile_even = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_even = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to( - sin_row.dtype - ) - - # odd-indexed elements of the head - odd_q_offsets = even_q_offsets + 1 - odd_k_offsets = even_k_offsets + 1 - odd_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( - tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd - ) - odd_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( - tl.arange(0, pad_hd // 2)[None, :] * 2 + 1 < hd - ) - q_tile_odd = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to( - sin_row.dtype - ) - k_tile_odd = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to( - sin_row.dtype - ) - - if not BACKWARD_PASS: - # y_even = x_even * cos - x_odd * sin - # y_odd = x_odd * cos + x_even * sin - new_q_tile_even = q_tile_even * cos_row - q_tile_odd * sin_row - tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) - new_q_tile_odd = q_tile_odd * cos_row + q_tile_even * sin_row - tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) - - new_k_tile_even = k_tile_even * cos_row - k_tile_odd * sin_row - tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) - new_k_tile_odd = k_tile_odd * cos_row + k_tile_even * sin_row - tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) - else: - # dy_even = dx_even * cos + dx_odd * sin - # dy_odd = dx_odd * cos - dx_even * sin - new_q_tile_even = q_tile_even * cos_row + q_tile_odd * sin_row - tl.store(q_ptr + even_q_offsets, new_q_tile_even, mask=even_q_mask) - new_q_tile_odd = q_tile_odd * cos_row - q_tile_even * sin_row - tl.store(q_ptr + odd_q_offsets, new_q_tile_odd, mask=odd_q_mask) - - new_k_tile_even = k_tile_even * cos_row + k_tile_odd * sin_row - tl.store(k_ptr + even_k_offsets, new_k_tile_even, mask=even_k_mask) - new_k_tile_odd = k_tile_odd * cos_row - k_tile_even * sin_row - tl.store(k_ptr + odd_k_offsets, new_k_tile_odd, mask=odd_k_mask) - - -def rope_paper_forward(q, k, cos, sin): - - # transpose it back to the physical shape because Triton looks at the physical storage - # note: q and k are incontiguous before the transformation and will become contiguous after transpose - q = q.transpose(1, 2) - k = k.transpose(1, 2) - - batch_size, seq_len, n_q_head, head_dim = q.shape - n_kv_head = k.shape[2] - pad_hd = triton.next_power_of_2(head_dim) - pad_n_q_head = triton.next_power_of_2(n_q_head) - pad_n_kv_head = triton.next_power_of_2(n_kv_head) - BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) - - n_row = batch_size * seq_len - - # ensure tensors passed into the kernel are contiguous. 
It will be no-op if they are already contiguous - q = q.contiguous() - k = k.contiguous() - cos = cos.contiguous() - sin = sin.contiguous() - - _triton_rope_paper[(n_row,)]( - q, - q.stride(1), - k, - k.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=False, - ) - return q.transpose(1, 2), k.transpose(1, 2), cos, sin - - -def rope_paper_backward(dq, dk, cos, sin): - dq = dq.transpose(1, 2) - dk = dk.transpose(1, 2) - - batch_size, seq_len, n_q_head, head_dim = dq.shape - n_kv_head = dk.shape[2] - pad_hd = triton.next_power_of_2(head_dim) - pad_n_q_head = triton.next_power_of_2(n_q_head) - pad_n_kv_head = triton.next_power_of_2(n_kv_head) - BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) - - n_row = batch_size * seq_len - - # ensure dq and dk are contiguous - dq = dq.contiguous() - dk = dk.contiguous() - - # backward is similar to forward except swapping few ops - _triton_rope_paper[(n_row,)]( - dq, - dq.stride(1), - dk, - dk.stride(1), - cos, - cos.stride(-2), - sin, - sin.stride(-2), - seq_len, - batch_size, - n_q_head, - n_kv_head, - head_dim, - pad_n_q_head, - pad_n_kv_head, - pad_hd, - BLOCK_SIZE=BLOCK_SIZE, - BACKWARD_PASS=True, - ) - return dq.transpose(1, 2), dk.transpose(1, 2) - - -class LigerRopePaperFunction(torch.autograd.Function): - """ - Triton implementation of the orignal Rotary Positional Embedding (RoPE) operation from RoFormer. - - Please find the corresponding HuggingFace implementation here: - https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/roformer/modeling_roformer.py#L309 - - """ - - @staticmethod - def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - q size: (bsz, n_q_head, seq_len, head_dim) - k size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim // 2) - sin size: (1, seq_len, head_dim // 2) - """ - q, k, cos, sin = rope_paper_forward(q, k, cos, sin) - ctx.save_for_backward(cos, sin) - return q, k - - def backward(ctx, dq, dk): - """ - dq size: (bsz, n_q_head, seq_len, head_dim) - dk size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim // 2) - sin size: (1, seq_len, head_dim // 2) - """ - - cos, sin = ctx.saved_tensors - dq, dk = rope_paper_backward(dq, dk, cos, sin) - return dq, dk, None, None, None, None diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 160c74898..ffb8235cc 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -24,9 +24,6 @@ ) from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401 from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401 -from liger_kernel.transformers.rope_paper import ( # noqa: F401 - liger_rotary_paper_pos_emb, -) from liger_kernel.transformers.swiglu import ( # noqa: F401 LigerBlockSparseTop2MLP, LigerPhi3SwiGLUMLP, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index e776163f9..2478d4ba2 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -13,7 +13,6 @@ from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction -from liger_kernel.ops.rope_paper import LigerRopePaperFunction from 
liger_kernel.ops.swiglu import LigerSiLUMulFunction diff --git a/src/liger_kernel/transformers/rope_paper.py b/src/liger_kernel/transformers/rope_paper.py deleted file mode 100644 index 2362e004d..000000000 --- a/src/liger_kernel/transformers/rope_paper.py +++ /dev/null @@ -1,20 +0,0 @@ -from liger_kernel.ops.rope_paper import LigerRopePaperFunction - - -def liger_rotary_paper_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """ - Applies Rotary Positional Embedding (RoPE) operation to query and key states. - - Args: - q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). - k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). - cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). - sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). - position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. - unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation. - """ - - return LigerRopePaperFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) diff --git a/test/transformers/test_rope_paper.py b/test/transformers/test_rope_paper.py deleted file mode 100644 index d2ae6f061..000000000 --- a/test/transformers/test_rope_paper.py +++ /dev/null @@ -1,163 +0,0 @@ -from test.utils import supports_bfloat16 - -import pytest -import torch -from transformers.models.roformer.modeling_roformer import ( - RoFormerSelfAttention, - RoFormerSinusoidalPositionalEmbedding, -) - -from liger_kernel.ops.rope_paper import LigerRopePaperFunction -from liger_kernel.transformers.functional import liger_rope_paper -from liger_kernel.transformers.rope_paper import liger_rotary_paper_pos_emb -from liger_kernel.utils import infer_device - -device = infer_device() - -SLEEP_SECONDS = 0.1 - -apply_rotary_pos_emb = RoFormerSelfAttention.apply_rotary_position_embeddings - - -@pytest.mark.parametrize( - "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", - [ - (1, 128, 32, 32, 64), - (2, 128, 32, 32, 64), - # different q/k heads - (1, 128, 32, 8, 64), - (2, 128, 32, 8, 64), - # weird shapes - # HuggingFace llama/mistral source code doesn't support odd head dimension - # so we don't test it here - (3, 423, 73, 213, 92), - (3, 423, 73, 155, 92), - ], -) -@pytest.mark.parametrize( - "dtype, atol, rtol", - [ - (torch.float32, 1e-5, 1e-5), - pytest.param( - torch.bfloat16, - 1e-1, - 1e-5, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - ], -) -def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol -): - - _tensor_q = ( - torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) - - _tensor_k = ( - torch.randn((bsz, seq_len, num_kv_heads, head_dim), device=device) - .transpose(1, 2) - .to(dtype) - ) - - q1 = _tensor_q.clone().requires_grad_(True) - k1 = _tensor_k.clone().requires_grad_(True) - - q2 = _tensor_q.clone().requires_grad_(True) - k2 = _tensor_k.clone().requires_grad_(True) - - rotary_emb = RoFormerSinusoidalPositionalEmbedding( - num_positions=seq_len, embedding_dim=head_dim - ).to(device) - sinusoidal_pos = rotary_emb((bsz, seq_len))[None, :, :].to(dtype) - - sin, cos = sinusoidal_pos.chunk(2, dim=-1) - # validate forward pass - hf_q, hf_k = apply_rotary_pos_emb(sinusoidal_pos[None, :, :, :], q1, k1) - tt_q, tt_k = 
liger_rotary_paper_pos_emb(q2, k2, cos, sin) - assert torch.allclose(hf_q, tt_q, atol=atol, rtol=rtol) - assert torch.allclose(hf_k, tt_k, atol=atol, rtol=rtol) - - # validate backward pass - dq, dk = ( - torch.randn_like(hf_q, device=device), - torch.randn_like(hf_k, device=device).to(dtype), - ) - - q1_grad, k1_grad = torch.autograd.grad( - (hf_q, hf_k), (q1, k1), (dq, dk), allow_unused=True - ) - q2_grad, k2_grad = torch.autograd.grad( - (tt_q, tt_k), (q2, k2), (dq.clone(), dk.clone()), allow_unused=True - ) - - assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) - assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize( - "bsz, seq_len, num_q_heads, num_kv_heads, head_dim", - [ - (1, 2, 2, 2, 8), - (1, 2, 1, 2, 8), - # weird shapes - (9, 7, 41, 41, 41), - ], -) -@pytest.mark.parametrize( - "dtype, atol, rtol", - [ - (torch.float32, 1e-5, 1e-5), - (torch.bfloat16, 1e-1, 1e-5), - ], -) -def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol -): - _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) - _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) - - q1 = _q.clone().requires_grad_(True) - q2 = _q.clone().requires_grad_(True) - - k1 = _k.clone().requires_grad_(True) - k2 = _k.clone().requires_grad_(True) - - rotary_emb = RoFormerSinusoidalPositionalEmbedding( - num_positions=seq_len, embedding_dim=head_dim - ).to(device) - - sinusoidal_pos = rotary_emb((bsz, seq_len))[None, None, :, :].to(dtype) - - sin, cos = sinusoidal_pos.chunk(2, dim=-1) - functional_q, functional_k = liger_rope_paper(q1, k1, cos, sin) - class_q, class_k = LigerRopePaperFunction.apply(q2, k2, cos, sin) - - assert torch.allclose(functional_q, class_q, atol=atol, rtol=rtol) - assert torch.allclose(functional_k, class_k, atol=atol, rtol=rtol) - - dq, dk = torch.randn_like(functional_q), torch.randn_like(functional_k) - - dq1, dk1 = dq.clone(), dk.clone() - dq2, dk2 = dq.clone(), dk.clone() - - q1_grad, k1_grad = torch.autograd.grad( - (functional_q, functional_k), - (q1, k1), - (dq1, dk1), - allow_unused=True, - ) - - q2_grad, k2_grad = torch.autograd.grad( - (class_q, class_k), - (q2, k2), - (dq2, dk2), - allow_unused=True, - ) - - assert torch.allclose(q1_grad, q2_grad, atol=atol, rtol=rtol) - assert torch.allclose(k1_grad, k2_grad, atol=atol, rtol=rtol) From c259efdc4257f8570b7123b2b1f7334f0018e59a Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Wed, 11 Dec 2024 17:35:30 +0800 Subject: [PATCH 7/8] Cleanup --- src/liger_kernel/transformers/functional.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 2478d4ba2..66316056b 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -171,8 +171,6 @@ def liger_rope( return LigerRopeFunction.apply( q, k, cos, sin, position_ids, unsqueeze_dim, paper_form ) -def liger_rope_paper(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - return LigerRopePaperFunction.apply(q, k, cos, sin, position_ids, unsqueeze_dim) def liger_swiglu(a, b): From 85996d593080401014674781a887fadd7c5600ad Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Wang Date: Wed, 11 Dec 2024 17:58:54 +0800 Subject: [PATCH 8/8] Cleanup --- test/transformers/test_transformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/transformers/test_transformers.py b/test/transformers/test_transformers.py 
index 3e6d79123..9601229ec 100644 --- a/test/transformers/test_transformers.py +++ b/test/transformers/test_transformers.py @@ -12,7 +12,6 @@ def test_import_from_root(): LigerPhi3SwiGLUMLP, LigerRMSNorm, LigerSwiGLUMLP, - liger_rotary_paper_pos_emb, liger_rotary_pos_emb, ) except Exception:
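
For reference, the two rotation conventions exercised by the updated test_rope.py differ only in how elements of each head are paired: the default path (validated against LlamaRotaryEmbedding and apply_rotary_pos_emb) rotates the first and second halves of the head dimension together, while the paper-form path (validated against RoFormer's apply_rotary_position_embeddings) rotates interleaved even/odd pairs, which is why the tests obtain cos and sin by chunking the RoFormer sinusoidal embedding along its last dimension. The following is a minimal eager-PyTorch sketch of the two forward rotations, assuming the tensor shapes used in the tests above; the helper names rotate_half_rope and paper_form_rope are illustrative only and are not part of the patch.

    import torch

    def rotate_half_rope(x, cos, sin):
        # Llama-style "rotate half": element i is paired with element i + head_dim // 2.
        # x: (bsz, n_heads, seq_len, head_dim); cos, sin: (1, seq_len, head_dim)
        x1, x2 = x.chunk(2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)
        return x * cos.unsqueeze(1) + rotated * sin.unsqueeze(1)

    def paper_form_rope(x, cos, sin):
        # RoFormer/paper form: element 2i is paired with element 2i + 1 (interleaved).
        # cos, sin: (1, seq_len, head_dim // 2), as produced by
        # sinusoidal_pos.chunk(2, dim=-1) in the tests.
        x_even, x_odd = x[..., 0::2], x[..., 1::2]
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
        y_even = x_even * cos - x_odd * sin  # same pairing as the kernel's forward pass
        y_odd = x_odd * cos + x_even * sin
        out = torch.empty_like(x)
        out[..., 0::2] = y_even
        out[..., 1::2] = y_odd
        return out

With the patch applied, liger_rotary_pos_emb(q, k, cos, sin, paper_form=True) is expected to agree with paper_form_rope applied to q and k within the tolerances used in the tests, while the default call matches rotate_half_rope. The kernel's backward pass is the same rotation with the sign of sin flipped (the transpose of an orthogonal rotation), as the comments in the BACKWARD_PASS branch indicate.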