
Commit

minor
xzyaoi committed Jul 3, 2024
1 parent 55bdfa7 commit e8b4be6
Showing 2 changed files with 23 additions and 16 deletions.
17 changes: 9 additions & 8 deletions benchmarks/bench_matmul.py
@@ -13,15 +13,16 @@ def fp16_func(x, weight_ref):
         return torch.matmul(x, weight_ref)
     def w4_2_4_func(qweight, x, meta, scale):
         return matmul_4bit_2_4(qweight, x, meta, scale)
-    fp16_result = timing_function(
-        fp16_func, flops_func, kwargs={"m": m, "n": n, "k": k, "x": x, "weight_ref": weight_ref}
-    )
-
     w4_2_4_result = timing_function(
-        w4_2_4_func, flops_func, kwargs={"m": m, "n": n, "k": k, "qweight": qweight, "x": x, "meta": meta, "scale": scale}
+        w4_2_4_func, flops_func, kwargs={"m": m, "n": n, "k": k, "qweight": qweight, "x": x, "meta": meta, "scale": scale}, repeats=5
     )
+    fp16_result = timing_function(
+        fp16_func, flops_func, kwargs={"m": m, "n": n, "k": k, "x": x, "weight_ref": weight_ref}, repeats=5
+    )
     results = [fp16_result, w4_2_4_result]
 
-    print_results_table("matmul_4bit_2_4", results)
-
+    print_results_table(f"matmul m={m},n={n},k={k}", results)
 
 if __name__ == "__main__":
-    benchmark(4096*2, 32, 4096*2)
+    benchmark(256, 32, 256)
+    benchmark(4096, 32, 4096)
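
Note: the flops_func passed to timing_function above is referenced but not shown in this diff. A minimal sketch of what such a helper could look like for a dense (m x k) @ (k x n) matmul, assuming the conventional 2*m*n*k FLOP count, is:

    def matmul_flops(m, n, k):
        # Hypothetical helper (not part of this commit): a dense matmul
        # performs roughly 2 * m * n * k floating-point operations.
        return 2 * m * n * k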
22 changes: 14 additions & 8 deletions triteia/python/utils/benchmark.py
@@ -4,27 +4,33 @@
 from rich.table import Table
 from triteia.python.configs.gpus.specs import get_gpu_device_info
 
-def timing_function(func, flops_func, kwargs):
+def timing_function(func, flops_func, kwargs, repeats=1):
     func_args_names = inspect.getfullargspec(func).args
     func_args = {arg: kwargs[arg] for arg in func_args_names if arg in kwargs}
     gpu_info = get_gpu_device_info()
     if flops_func:
         flops_func_args_names = inspect.getfullargspec(flops_func).args
         flops_func_args = {arg: kwargs[arg] for arg in flops_func_args_names if arg in kwargs}
+    elapseds = []
 
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
-    start.record()
-    output = func(**func_args)
-    end.record()
-    torch.cuda.synchronize()
-    elapsed = start.elapsed_time(end)
+    for i in range(repeats):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        output = func(**func_args)
+        end.record()
+        torch.cuda.synchronize()
+        elapsed = start.elapsed_time(end)
+        elapseds.append(elapsed)
 
+    elapsed = sum(elapseds)/repeats
+
     if flops_func:
         total_flops = flops_func(**flops_func_args) # FLOPS
         perf_flops = total_flops/elapsed # FlOPS/ms
         if gpu_info:
             mfu = 100 * perf_flops/1e9/gpu_info["fp16_tflops"]
 
     return {
         "output": output,
         "elapsed": elapsed, # ms
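
The reworked timing_function now runs the kernel repeats times, timing each run with CUDA events, and reports the mean latency. A self-contained sketch of the same measurement pattern (standalone, not the project's API; it assumes a CUDA-capable device and uses a plain fp16 torch.matmul as the workload) is:

    import torch

    def time_cuda_ms(fn, repeats=5):
        # Time fn() on the GPU with CUDA events and return the mean latency in ms.
        elapseds = []
        for _ in range(repeats):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()  # wait for the kernel so elapsed_time() is valid
            elapseds.append(start.elapsed_time(end))
        return sum(elapseds) / repeats

    if __name__ == "__main__":
        x = torch.randn(32, 4096, device="cuda", dtype=torch.float16)
        w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
        ms = time_cuda_ms(lambda: torch.matmul(x, w))
        print(f"mean matmul latency: {ms:.3f} ms")

Since elapsed_time() returns milliseconds, perf_flops = total_flops/elapsed in the diff is FLOPs per millisecond; dividing by 1e9 converts it to TFLOP/s, which is then compared against the GPU's peak fp16 TFLOPS to obtain the MFU percentage.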
