Fix All chunked_loss Benchmark Scripts #438

Merged · 3 commits · Jan 1, 2025
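All four diffs apply the same fix: each benchmark script deletes its local copies of the Torch and Liger LM-head modules, imports the shared implementations from the corresponding test module under test/chunked_loss/, and wraps them in lambdas that keep only the first element of the returned value. Below is a minimal self-contained sketch of the pattern; the stand-in head and its cross-entropy loss are placeholders to keep the sketch runnable (the real heads compute the respective preference losses), and the trailing [0] assumes, as the diffs suggest, that the test heads return a tuple with the scalar loss first.

import torch

H, V = 512, 1000                      # illustrative sizes; the real scripts read
dtype, device = torch.float32, "cpu"  # these from the benchmark config


class TorchLMHeadCPO(torch.nn.Module):
    # Stand-in for the shared test implementation: a bias-free LM head whose
    # forward returns (loss, *aux). The real head computes the CPO loss.
    def __init__(self, H: int, V: int, dtype: torch.dtype):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)

    def forward(self, x, y):
        logits = self.lin(x)
        loss = torch.nn.functional.cross_entropy(logits.view(-1, V), y.view(-1))
        return loss, logits


# The wrapper style the PR introduces: build a fresh head per call, keep the loss.
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
    device
)(x, target)[0]

_input = torch.randn(2, 8, H, dtype=dtype, device=device, requires_grad=True)
target = torch.randint(V, (2, 8), dtype=torch.long, device=device)
loss = torch_lm_head_cpo(_input, target)
loss.backward()  # gradients flow through the wrapper, as in the full() closures

One consequence of this style is that a fresh torch.nn.Linear is constructed and moved to the device on every call, so each benchmarked invocation includes module construction; the trade-off is that the model definitions stay single-sourced in the test modules.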
50 changes: 16 additions & 34 deletions benchmark/scripts/benchmark_cpo_loss.py
@@ -19,36 +19,6 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadCPO(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.

:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_cpo_loss import HFCPOLoss

super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.cpo_loss = HFCPOLoss().get_batch_loss_metrics

def forward(self, x, y):
return self.cpo_loss(x, self.lin.weight, y)


class LigerLMHeadCPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.cpo_loss = LigerFusedLinearCPOFunction.apply

def forward(self, x, y):
return self.cpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################
@@ -57,15 +27,21 @@ def forward(self, x, y):
def bench_memory_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -96,6 +72,8 @@ def full():
def bench_speed_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
@@ -104,8 +82,12 @@ def bench_speed_fused_linear_cpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_cpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_cpo = LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
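The CPO hunks elide the fwd/full closures, but the hunk context (def full():) and the fully visible DPO script below suggest they keep their existing shape — dispatch on the provider, then backpropagate for the full pass. Presumably, inside each bench function (provider, _input, and target are locals of that function):

def fwd():
    if provider == "liger":
        return liger_lm_head_cpo(_input, target)
    elif provider == "huggingface":
        return torch_lm_head_cpo(_input, target)


def full():
    y = fwd()
    y.backward()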
84 changes: 26 additions & 58 deletions benchmark/scripts/benchmark_dpo_loss.py
@@ -1,3 +1,6 @@
import os
import sys

import torch
import triton

@@ -14,58 +17,12 @@

device = infer_device()


class TorchDPOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
beta: float = 0.1,
ignore_index: int = -100,
bias: bool = False,
):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.dpo_loss = HF_DPO_Loss(beta=beta, ignore_index=ignore_index)

def forward(self, x, target):
return self.dpo_loss.get_batch_loss_metrics(
x,
self.lin.weight,
target,
self.lin.bias if hasattr(self.lin, "bias") else None,
)


class LigerDPOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
beta: float = 0.1,
ignore_index: int = -100,
bias: bool = False,
):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
self.beta = beta
self.ignore_index = ignore_index

def forward(self, x, target):
return LigerFusedLinearDPOFunction.apply(
x,
self.lin.weight,
target,
self.lin.bias if hasattr(self.lin, "bias") else None,
self.ignore_index,
self.beta,
True,
)
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
@@ -76,11 +33,16 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)(x, ref_x, target)[0]
liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)(x, ref_x, target)[0]

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

@@ -91,9 +53,9 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:

def fwd():
if provider == "liger":
return liger_dpo_loss(_input, target)
return liger_dpo_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_dpo_loss(_input, target)
return torch_dpo_loss(_input, ref_input, target)

def full():
y = fwd()
@@ -108,6 +70,8 @@ def full():


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
@@ -119,12 +83,16 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_dpo_loss = TorchDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
liger_dpo_loss = LigerDPOLoss(H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias).to(device)
torch_dpo_loss = lambda x, ref_x, target: TorchLMHeadDPO(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)(x, ref_x, target)[0]
liger_dpo_loss = lambda x, ref_x, target: LigerLMHeadDPO(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)(x, ref_x, target)[0]

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

@@ -135,9 +103,9 @@ def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:

def fwd():
if provider == "liger":
return liger_dpo_loss(_input, target)
return liger_dpo_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_dpo_loss(_input, target)
return torch_dpo_loss(_input, ref_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
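Beyond the shared lambda refactor, the DPO benchmark gains a reference-model input: ref_input is allocated with requires_grad=False — the reference policy in DPO is frozen, so no gradients should flow through it — and is threaded through both loss paths as the second positional argument:

# Grounded in the diff above; B, T, H and the wrappers are locals of the bench functions.
ref_input = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=False)
loss = liger_dpo_loss(_input, ref_input, target)  # (policy input, reference input, target)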
50 changes: 16 additions & 34 deletions benchmark/scripts/benchmark_orpo_loss.py
@@ -19,36 +19,6 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadORPO(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.

:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_orpo_loss import HF_ORPO_Loss

super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.orpo_loss = HF_ORPO_Loss().get_batch_loss_metrics

def forward(self, x, y):
return self.orpo_loss(x, self.lin.weight, y)


class LigerLMHeadORPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.orpo_loss = LigerFusedLinearORPOFunction.apply

def forward(self, x, y):
return self.orpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################
@@ -57,15 +27,21 @@ def forward(self, x, y):
def bench_memory_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -96,6 +72,8 @@ def full():
def bench_speed_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
@@ -104,8 +82,12 @@ def bench_speed_fused_linear_orpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
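The ORPO diff is line-for-line parallel to the CPO one: the local TorchLMHeadORPO/LigerLMHeadORPO classes are deleted, the shared heads are imported from test.chunked_loss.test_orpo_loss inside each bench function, and the same two-argument lambda wrappers taking [0] are installed.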
50 changes: 16 additions & 34 deletions benchmark/scripts/benchmark_simpo_loss.py
@@ -19,36 +19,6 @@
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadSimPO(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.

:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_cpo_loss import HFCPOLoss

super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics

def forward(self, x, y):
return self.simpo_loss(x, self.lin.weight, y)


class LigerLMHeadSimPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype)
self.simpo_loss = LigerFusedLinearSimPOFunction.apply

def forward(self, x, y):
return self.simpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the linear fused cross entropy loss
#############################################################################
@@ -57,15 +27,21 @@ def forward(self, x, y):
def bench_memory_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
@@ -96,6 +72,8 @@ def full():
def bench_speed_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
@@ -104,8 +82,12 @@ def bench_speed_fused_linear_simpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
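One detail worth flagging in the SimPO script: its Torch baseline is TorchLMHeadCPO, imported through test.chunked_loss.test_simpo_loss rather than a dedicated SimPO head. That matches the deleted local class, whose ground truth was the HF CPO loss run in SimPO mode (HFCPOLoss(loss_type="simpo")), so the CPO torch head presumably serves as the SimPO reference as well.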