Merge branch 'main' into main

hebiao064 authored Dec 11, 2024
2 parents 36e483c + 78e8a85 commit 7d96535
Showing 12 changed files with 59 additions and 46 deletions.
9 changes: 5 additions & 4 deletions dev/modal/tests.py
@@ -4,6 +4,7 @@
import modal

ROOT_PATH = Path(__file__).parent.parent.parent
+REMOTE_ROOT_PATH = "/root/liger-kernel"

# REBUILD_IMAGE is an environment variable that is set to "true" in the nightly build
REBUILD_IMAGE = os.getenv("REBUILD_IMAGE") is not None
@@ -17,13 +18,13 @@
app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container
-repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")
+repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
def liger_tests():
import subprocess

subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel")
subprocess.run(["pip", "install", "-e", "."], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test"], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test-convergence"], check=True, cwd=REMOTE_ROOT_PATH)
9 changes: 5 additions & 4 deletions dev/modal/tests_bwd.py
@@ -4,6 +4,7 @@
import modal

ROOT_PATH = Path(__file__).parent.parent.parent
+REMOTE_ROOT_PATH = "/root/liger-kernel"

# REBUILD_IMAGE is an environment variable that is set to "true" in the nightly build
REBUILD_IMAGE = os.getenv("REBUILD_IMAGE") is not None
@@ -22,13 +23,13 @@
app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container
-repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")
+repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
def liger_tests_bwd():
import subprocess

subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel")
subprocess.run(["pip", "install", "-e", "."], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test"], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test-convergence"], check=True, cwd=REMOTE_ROOT_PATH)
11 changes: 2 additions & 9 deletions examples/alignment/run_orpo.py
@@ -1,9 +1,9 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
-from trl import ORPOConfig, ORPOTrainer # noqa: F401
+from trl import ORPOConfig # noqa: F401

-from liger_kernel.transformers import LigerORPOTrainer # noqa: F401
+from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct",
@@ -19,13 +19,6 @@

train_dataset = load_dataset("trl-lib/tldr-preference", split="train")

-# train_dataset = train_dataset.map(
-# lambda example: {
-# "prompt": example["prompt"],
-# "chosen": example["chosen"][0]["content"],
-# "rejected": example["rejected"][0]["content"],
-# }
-# )
training_args = ORPOConfig(
output_dir="Llama3.2_1B_Instruct",
beta=0.1,
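
With the trainer import moved (see the run_orpo.py changes above and the new trainer package below), a complete run looks roughly like the following. This is an illustrative sketch rather than the contents of the updated example script: the tokenizer setup, the processing_class argument, and the final train() call are assumptions, and any extra ORPOConfig or from_pretrained options used in the real script are omitted.

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig

from liger_kernel.transformers.trainer import LigerORPOTrainer  # new import path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: give the trainer a pad token

# Preference pairs; the commit drops the previous remapping and uses the dataset as-is.
train_dataset = load_dataset("trl-lib/tldr-preference", split="train")

training_args = ORPOConfig(output_dir="Llama3.2_1B_Instruct", beta=0.1)

# Assumed to follow trl's ORPOTrainer interface; older trl releases take tokenizer=
# instead of processing_class=.
trainer = LigerORPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
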
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "liger_kernel"
version = "0.5.0"
version = "0.5.1"
description = "Efficient Triton kernels for LLM Training"
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
readme = { file = "README.md", content-type = "text/markdown" }
25 changes: 13 additions & 12 deletions src/liger_kernel/ops/qwen2vl_mrope.py
@@ -10,6 +10,7 @@ def _triton_qwen2vl_mrope(
cos,
sin,
sl,
+bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
@@ -41,13 +42,12 @@
t_end = mrope_section_t
h_end = t_end + mrope_section_h

-cos_row_idx = pid % sl
-t_cos = cos + cos_row_idx * hd
-h_cos = t_cos + sl * hd
-w_cos = h_cos + sl * hd
-t_sin = sin + cos_row_idx * hd
-h_sin = t_sin + sl * hd
-w_sin = h_sin + sl * hd
+t_cos = cos + pid * hd
+h_cos = t_cos + bs * sl * hd
+w_cos = h_cos + bs * sl * hd
+t_sin = sin + pid * hd
+h_sin = t_sin + bs * sl * hd
+w_sin = h_sin + bs * sl * hd

cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
@@ -151,6 +151,7 @@ def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
cos,
sin,
seq_len,
+batch_size,
n_q_head,
n_kv_head,
head_dim,
@@ -189,6 +190,7 @@ def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
cos,
sin,
seq_len,
+batch_size,
n_q_head,
n_kv_head,
head_dim,
@@ -216,8 +218,8 @@ def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
-cos size: (3, 1, seq_len, head_dim)
-sin size: (3, 1, seq_len, head_dim)
+cos size: (3, bsz, seq_len, head_dim)
+sin size: (3, bsz, seq_len, head_dim)
"""
q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
ctx.save_for_backward(cos, sin)
@@ -228,10 +230,9 @@ def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
-cos size: (3, 1, seq_len, head_dim)
-sin size: (3, 1, seq_len, head_dim)
+cos size: (3, bsz, seq_len, head_dim)
+sin size: (3, bsz, seq_len, head_dim)
"""

cos, sin = ctx.saved_tensors
mrope_section = ctx.mrope_section
dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
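
The kernel rewrite above changes the cos/sin layout from (3, 1, seq_len, head_dim) to (3, bsz, seq_len, head_dim): each Triton program now reads its row at pid * hd, and the height/width sections sit bs * sl * hd elements after the temporal one instead of sl * hd. A small, self-contained PyTorch check of that flattened layout, assuming one program per (batch, position) row and using made-up sizes:

import torch

bsz, seq_len, head_dim = 2, 4, 8              # hypothetical sizes
cos = torch.randn(3, bsz, seq_len, head_dim)  # new layout: (3, bsz, seq_len, head_dim)
flat = cos.contiguous().view(-1)

for pid in range(bsz * seq_len):              # one row per (batch, position) pair
    t_off = pid * head_dim                            # t_cos = cos + pid * hd
    h_off = t_off + bsz * seq_len * head_dim          # h_cos = t_cos + bs * sl * hd
    w_off = h_off + bsz * seq_len * head_dim          # w_cos = h_cos + bs * sl * hd
    b, s = divmod(pid, seq_len)
    assert torch.equal(flat[t_off:t_off + head_dim], cos[0, b, s])
    assert torch.equal(flat[h_off:h_off + head_dim], cos[1, b, s])
    assert torch.equal(flat[w_off:w_off + head_dim], cos[2, b, s])
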
1 change: 0 additions & 1 deletion src/liger_kernel/transformers/__init__.py
@@ -22,7 +22,6 @@
apply_liger_kernel_to_qwen2,
apply_liger_kernel_to_qwen2_vl,
)
-from liger_kernel.transformers.orpo_trainer import LigerORPOTrainer # noqa: F401
from liger_kernel.transformers.rms_norm import LigerRMSNorm # noqa: F401
from liger_kernel.transformers.rope import liger_rotary_pos_emb # noqa: F401
from liger_kernel.transformers.swiglu import ( # noqa: F401
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/qwen2vl_mrope.py
@@ -8,8 +8,8 @@ def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim
Args:
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
-cos (torch.Tensor): The cosine tensor of shape (3, 1, seq_len, head_dim).
-sin (torch.Tensor): The sine tensor of shape (3, 1, seq_len, head_dim).
+cos (torch.Tensor): The cosine tensor of shape (3, bsz, seq_len, head_dim).
+sin (torch.Tensor): The sine tensor of shape (3, bsz, seq_len, head_dim).
mrope_section (List[int]): The multimodal rope section for channel dimension of temporal, height and width in rope calculation.
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
6 changes: 6 additions & 0 deletions src/liger_kernel/transformers/trainer/__init__.py
@@ -0,0 +1,6 @@
+try:
+    from liger_kernel.transformers.trainer.orpo_trainer import (  # noqa: F401
+        LigerORPOTrainer,
+    )
+except ImportError:
+    raise ImportError("Please `pip install trl` to use LigerORPOTrainer")
@@ -76,9 +76,7 @@ def concatenated_forward(
padding_value=self.padding_value,
device=self.accelerator.device,
)
-# if self.accelerator.is_main_process:
-# import pdb; pdb.set_trace()
-# torch.distributed.barrier()

model_kwargs = (
{
"decoder_input_ids": self._shift_right(
13 changes: 9 additions & 4 deletions test/convergence/test_mini_models_multimodal.py
@@ -16,7 +16,9 @@

import pytest
import torch
+import transformers
from datasets import load_dataset
+from packaging import version
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerFast

@@ -378,8 +380,9 @@ def run_mini_model_multimodal(
5e-3,
1e-5,
marks=pytest.mark.skipif(
-not QWEN2_VL_AVAILABLE,
-reason="Qwen2-VL not available in this version of transformers",
+not QWEN2_VL_AVAILABLE
+or version.parse(transformers.__version__) >= version.parse("4.47.0"),
+reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
),
),
pytest.param(
@@ -398,8 +401,10 @@ def run_mini_model_multimodal(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
pytest.mark.skipif(
-not QWEN2_VL_AVAILABLE,
-reason="Qwen2-VL not available in this version of transformers",
+not QWEN2_VL_AVAILABLE
+or version.parse(transformers.__version__)
+>= version.parse("4.47.0"),
+reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
),
],
),
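
Both convergence suites now gate the Qwen2-VL cases on the installed transformers version as well as on model availability; the same edit appears in test_mini_models_with_logits.py below. Condensed, the added guard looks like this (a standalone sketch with a hypothetical availability flag and test name):

import pytest
import transformers
from packaging import version

QWEN2_VL_AVAILABLE = True  # assumed to be detected elsewhere in the test module

@pytest.mark.skipif(
    not QWEN2_VL_AVAILABLE
    or version.parse(transformers.__version__) >= version.parse("4.47.0"),
    reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
)
def test_mini_qwen2_vl():  # hypothetical test name
    ...
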
13 changes: 9 additions & 4 deletions test/convergence/test_mini_models_with_logits.py
@@ -18,7 +18,9 @@

import pytest
import torch
+import transformers
from datasets import load_from_disk
+from packaging import version
from torch.utils.data import DataLoader
from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
@@ -538,8 +540,9 @@ def run_mini_model(
5e-3,
1e-5,
marks=pytest.mark.skipif(
-not QWEN2_VL_AVAILABLE,
-reason="Qwen2-VL not available in this version of transformers",
+not QWEN2_VL_AVAILABLE
+or version.parse(transformers.__version__) >= version.parse("4.47.0"),
+reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
),
),
pytest.param(
@@ -558,8 +561,10 @@ def run_mini_model(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
pytest.mark.skipif(
-not QWEN2_VL_AVAILABLE,
-reason="Qwen2-VL not available in this version of transformers",
+not QWEN2_VL_AVAILABLE
+or version.parse(transformers.__version__)
+>= version.parse("4.47.0"),
+reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
),
],
),
8 changes: 6 additions & 2 deletions test/transformers/test_qwen2vl_mrope.py
@@ -73,7 +73,9 @@ def test_correctness(
k2 = _tensor_k.clone().requires_grad_(True)

# NOTE: this position ids distribution is different from the real one, just to test op correctness
-pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
+pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
+3, bsz, seq_len
+)
cos, sin = rotary_emb(k1, pos_ids)

# validate forward pass
@@ -130,7 +132,9 @@ def test_functional_correctness(

rotary_emb = Qwen2VLRotaryEmbedding(head_dim, device=device)

-pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
+pos_ids = torch.arange(seq_len * 3 * bsz, device=device, dtype=torch.long).view(
+3, bsz, seq_len
+)
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_qwen2vl_mrope(q1, k1, cos, sin, mrope_section)
