Skip to content

Commit

Permalink
benchmark infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
xzyaoi committed Jul 3, 2024
1 parent bbc55b7 commit 55bdfa7
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 243 deletions.
12 changes: 0 additions & 12 deletions .github/FUNDING.yml

This file was deleted.

68 changes: 0 additions & 68 deletions .github/init.sh

This file was deleted.

3 changes: 0 additions & 3 deletions .github/release_message.sh

This file was deleted.

36 changes: 0 additions & 36 deletions .github/rename_project.sh

This file was deleted.

25 changes: 0 additions & 25 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -1,25 +0,0 @@
name: CI

on:
push:
branches: [ version/v2 ]
pull_request:
branches: [ version/v2 ]
workflow_dispatch:

jobs:
tests_linux:
strategy:
fail-fast: false
runs-on: [self-hosted, linux]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: make install
- name: Run tests
run: make test
- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v3
50 changes: 0 additions & 50 deletions .github/workflows/release.yml

This file was deleted.

42 changes: 0 additions & 42 deletions .github/workflows/rename_project.yml

This file was deleted.

27 changes: 27 additions & 0 deletions benchmarks/bench_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch
from triteia.python.ops import matmul_4bit_2_4, gen_sparse_quant4_NT
from triteia.python.utils import timing_function, print_results_table

flops_func = lambda m, n, k: 2 * m * n * k

def benchmark(m,n,k, dev="cuda", groupsize=-1):
x = torch.randn((n, k), dtype=torch.float16, device=dev)
weight_ref, qweight, scale, meta = gen_sparse_quant4_NT(
m, k, groupsize=groupsize, device=dev
)
def fp16_func(x, weight_ref):
return torch.matmul(x, weight_ref)
def w4_2_4_func(qweight, x, meta, scale):
return matmul_4bit_2_4(qweight, x, meta, scale)
fp16_result = timing_function(
fp16_func, flops_func, kwargs={"m": m, "n": n, "k": k, "x": x, "weight_ref": weight_ref}
)
w4_2_4_result = timing_function(
w4_2_4_func, flops_func, kwargs={"m": m, "n": n, "k": k, "qweight": qweight, "x": x, "meta": meta, "scale": scale}
)
results = [fp16_result, w4_2_4_result]

print_results_table("matmul_4bit_2_4", results)

if __name__ == "__main__":
benchmark(4096*2, 32, 4096*2)
7 changes: 2 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
# This template is a low-dependency template.
# By default there is no requirements added here.
# Add the requirements you need to this file.
# or run `make init` to create this file automatically based on the template.
# You can also run `make switch-to-poetry` to use the poetry package manager.
pynvml
rich
19 changes: 17 additions & 2 deletions triteia/python/configs/gpus/specs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from pynvml import *
nvmlInit()

nvidia_rtx_3090 = {
'name': 'NVIDIA RTX 3090',
'name': 'NVIDIA GeForce RTX 3090',
'compute_capability': '8.6',
'memory': 24, # in GB
'bandwidth': 936.2,
'fp16_tflops': 35.58,
'fp32_tflops': 35.58,
}

nvidia_gpus = [nvidia_rtx_3090]
nvidia_gpus = [nvidia_rtx_3090]

def get_gpu_device_info():
deviceCount = nvmlDeviceGetCount()
name = None
for i in range(deviceCount):
handle = nvmlDeviceGetHandleByIndex(i)
name = nvmlDeviceGetName(handle)
break
for gpu in nvidia_gpus:
if gpu['name'] == name:
return gpu
return None
3 changes: 3 additions & 0 deletions triteia/python/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .benchmark import timing_function, print_results_table

__all__ = ["timing_function", "print_results_table"]
54 changes: 54 additions & 0 deletions triteia/python/utils/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import inspect
import torch
from rich.console import Console
from rich.table import Table
from triteia.python.configs.gpus.specs import get_gpu_device_info

def timing_function(func, flops_func, kwargs):
func_args_names = inspect.getfullargspec(func).args
func_args = {arg: kwargs[arg] for arg in func_args_names if arg in kwargs}
gpu_info = get_gpu_device_info()
if flops_func:
flops_func_args_names = inspect.getfullargspec(flops_func).args
flops_func_args = {arg: kwargs[arg] for arg in flops_func_args_names if arg in kwargs}

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
output = func(**func_args)
end.record()
torch.cuda.synchronize()
elapsed = start.elapsed_time(end)

if flops_func:
total_flops = flops_func(**flops_func_args) # FLOPS
perf_flops = total_flops/elapsed # FlOPS/ms
if gpu_info:
mfu = 100 * perf_flops/1e9/gpu_info["fp16_tflops"]
return {
"output": output,
"elapsed": elapsed, # ms
"func_name": func.__name__,
"total_flops": total_flops/1e9 if flops_func else None, # GFLOPS
"perf_flops": perf_flops/1e6 if flops_func else None, # GFLOPS/s
"mfu": mfu if flops_func and gpu_info else None,
"args": kwargs
}

def print_results_table(title, results):
table = Table(title = title)
table.add_column("Func Name")
table.add_column("Elapsed (ms)")
table.add_column("Total FLOPS (GFLOPS)")
table.add_column("Perf FLOPS (GFLOPS/s)")
table.add_column("MFU (%)")
for result in results:
table.add_row(
result["func_name"],
f"{result['elapsed']:.2f}",
f"{result['total_flops']:.2f}" if result["total_flops"] else None,
f"{result['perf_flops']:.2f}" if result["perf_flops"] else None,
f"{result['mfu']:.2f}" if result["mfu"] else None
)
console = Console()
console.print(table)

0 comments on commit 55bdfa7

Please sign in to comment.