Add Benchmark Result
hebiao064 committed Dec 13, 2024
1 parent 6e67869 commit 3a76c76
Showing 3 changed files with 324 additions and 0 deletions.
30 changes: 30 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,30 @@
## Benchmarking Liger Kernels

Follow these steps to benchmark and visualize kernel performance:

1. Create a benchmark script
    - Add your script under `benchmark/scripts/`
    - Name it after the kernel (e.g., `benchmark_<kernel_name>.py`); a minimal skeleton is sketched below
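
A minimal sketch of a new script, modeled on the KTO example added in this commit. The helpers come from `benchmark/scripts/utils.py`; `my_kernel`, the placeholder `torch.relu` op, and the `{"T": 512}` config are hypothetical stand-ins for your kernel:

```python
import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_my_kernel(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    # Build inputs from the swept x value and the extra config, then time the op.
    B = input.x
    T = input.extra_benchmark_config["T"]
    x = torch.randn(B, T, device=device)
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        lambda: torch.relu(x),  # placeholder op; call your kernel here
        rep=100,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    run_benchmarks(
        bench_test_fn=bench_speed_my_kernel,
        kernel_name="my_kernel",
        kernel_operation_modes=["forward"],
        metric_name="speed",
        metric_unit="ms",
        x_name="B",
        x_label="Batch Size (B)",
        x_values=[2**i for i in range(1, 6)],
        kernel_providers=["liger"],
        extra_benchmark_configs=[{"T": 512}],
        overwrite=args.overwrite,
    )
```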

2. Run the benchmark
    - Run your script from the `benchmark` directory; results are appended to `benchmark/data/all_benchmark_data.csv`
    - Each row records the kernel, provider, operation mode, metric, the swept x value, three quantile measurements (median, 20th, and 80th percentile), the extra benchmark config (JSON), the GPU, a timestamp, and the liger-kernel version

Example: Benchmarking KTO Loss
```bash
cd benchmark
python scripts/benchmark_kto_loss.py
```

3. Visualize results
- Use the visualization script with appropriate parameters

Example: Visualizing KTO Loss benchmark results
```bash
python benchmarks_visualizer.py \
--kernel-name kto_loss \
--metric-name memory \
--kernel-operation-mode full
```
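
Here `--kernel-name` matches the `kernel_name` set in the benchmark script, `--metric-name` is `speed` or `memory`, and `--kernel-operation-mode` is one of the modes the script ran (`forward`, `backward`, or `full`).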

4. View results
- Generated plots will be saved in `benchmark/visualizations/`
30 changes: 30 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,8.2532958984375,8.235372543334961,8.274937629699707,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,16.888959884643555,16.879615783691406,16.898893356323242,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,32.13854217529297,32.12795639038086,32.149131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,64.81161499023438,64.81161499023438,64.81161499023438,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,128.68646240234375,128.68646240234375,128.68646240234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.146656036376953,7.143622398376465,7.152345657348633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,12.538240432739258,12.521356582641602,12.540371894836426,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,26.29542350769043,25.303590774536133,26.88591957092285,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,49.26508712768555,49.26508712768555,49.26508712768555,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,98.9525146484375,98.9525146484375,98.9525146484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,9.005151748657227,8.97766399383545,9.046483039855957,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,19.108863830566406,19.09713363647461,19.185260772705078,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.80137634277344,32.775360107421875,32.827388763427734,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,65.46678161621094,65.46678161621094,65.46678161621094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,129.91734313964844,129.91734313964844,129.91734313964844,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,16.091487884521484,14.86076831817627,16.23084831237793,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,28.04204750061035,28.03957176208496,28.055641174316406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,54.70073699951172,54.70073699951172,54.70073699951172,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,108.09929656982422,108.09929656982422,108.09929656982422,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,215.1945343017578,215.1945343017578,215.1945343017578,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,3037.75390625,3037.75390625,3037.75390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3800.0126953125,3800.0126953125,3800.0126953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,4565.28076171875,4565.28076171875,4565.28076171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,4589.31787109375,4589.31787109375,4589.31787109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,4637.39208984375,4637.39208984375,4637.39208984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4793.7626953125,4793.7626953125,4793.7626953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6551.2978515625,6551.2978515625,6551.2978515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,10063.3681640625,10063.3681640625,10063.3681640625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,17093.5078125,17093.5078125,17093.5078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,31153.7890625,31153.7890625,31153.7890625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
264 changes: 264 additions & 0 deletions benchmark/scripts/benchmark_kto_loss.py
@@ -0,0 +1,264 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
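# Make the repo root importable so the HF reference loss can be imported from the test suite.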
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchKTOLoss(torch.nn.Module):
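    """Torch reference: KTO loss over policy/reference linear heads, via HFKTOLoss from the test suite."""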
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
from test.chunked_loss.test_kto_loss import HFKTOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = HFKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
).get_batch_loss_metrics

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


class LigerKTOLoss(torch.nn.Module):
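    """Liger fused linear + KTO loss over the same policy/reference linear heads."""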
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = LigerFusedLinearKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
)

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_kto_loss = TorchKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

liger_kto_loss = LigerKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

def full():
y = fwd()
y.backward()

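    # Peak GPU memory over 10 forward+backward iterations, reported at the configured QUANTILES.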
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_kto_loss = TorchKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
liger_kto_loss = LigerKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

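    # "forward" times the forward pass alone; "backward" reuses one forward and times
    # repeated backward passes (retain_graph=True); "full" times forward + backward together.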
if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

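    # Sweep batch size B over {2, 4, 8, 16, 32} with a fixed config:
    # T=512, H=1024, V=128256, bf16, bias=True, beta=0.1, ignore_index=42.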
common_configs = {
"kernel_name": "kto_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_kto_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)

run_benchmarks(
bench_test_fn=bench_memory_kto_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
