Feature/utils (#11)
* migrate utils

* minor fix

* minor

* minor

* add utilities

* format

* play with tk
xzyaoi authored Jul 16, 2024
1 parent ee8ca0e commit cf7dc90
Showing 103 changed files with 10,971 additions and 173 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -130,4 +130,5 @@ dmypy.json

# templates
.github/templates/*
.vscode/
.local
9 changes: 7 additions & 2 deletions benchmarks/bench_bmm.py
@@ -4,7 +4,11 @@
    gen_batched_sparse_quant4_NT,
    bmm_4bit_2_4,
)
-from triteia.python.utils import timing_function, print_results_table, export_benchmark_results
+from triteia.python.utils import (
+    timing_function,
+    print_results_table,
+    export_benchmark_results,
+)
from triteia.python.configs.models.llama import llama_shapes

flops_func = lambda b, m, n, k: 2 * b * m * n * k
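(The FLOP model counts each of the b·m·n output elements as k multiply-add pairs, i.e. 2·b·m·n·k floating-point operations per batched matmul; bench_matmul.py and bench_sbmm.py below use the same convention.)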
@@ -65,9 +69,10 @@ def w4_2_4_native(qweight, x, meta, scale):
print_results_table(f"bmm b={b},m={m},n={n},k={k}", results)
return results


if __name__ == "__main__":
results = []
results.append(benchmark(2, 256, 32, 256))
results.append(benchmark(8, 4096, 32, 4096))
results.append(benchmark(8, 8192, 8, 8192))
export_benchmark_results(results, ".local/bmm_bench.json")
export_benchmark_results(results, ".local/bmm_bench.json")
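All three benchmark scripts in this commit share this driver pattern: wrap each kernel variant with timing_function, render the collected rows with print_results_table, and persist them with export_benchmark_results. The stand-in sketch below mirrors that flow in plain Python; the helper signatures here are assumptions for illustration, not the actual API of triteia.python.utils.

import json
import time


def timing_function(func, flops_func, kwargs):
    """Stand-in: time one call and attach a rough FLOP/s estimate."""
    start = time.perf_counter()
    func(**kwargs)
    elapsed = time.perf_counter() - start
    return {
        "name": getattr(func, "__name__", "func"),
        "seconds": elapsed,
        "tflops": flops_func(**kwargs) / elapsed / 1e12,
    }


def print_results_table(title, results):
    """Stand-in: print one row per timed variant."""
    print(title)
    for r in results:
        print(f"  {r['name']:<10} {r['seconds']:.6f}s  {r['tflops']:.3f} TFLOP/s")


def export_benchmark_results(results, path):
    """Stand-in: dump all collected rows as JSON."""
    with open(path, "w") as f:
        json.dump(results, f, indent=2)


def benchmark(m, n, k):
    def workload(m, n, k):
        # Placeholder compute; the real scripts time the 4-bit 2:4 kernels here.
        return sum(i % 7 for i in range(m * n))

    flops_func = lambda m, n, k: 2 * m * n * k
    results = [timing_function(workload, flops_func, {"m": m, "n": n, "k": k})]
    print_results_table(f"demo m={m},n={n},k={k}", results)
    return results


if __name__ == "__main__":
    results = [benchmark(64, 64, 64)]
    export_benchmark_results(results, "demo_bench.json")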
9 changes: 7 additions & 2 deletions benchmarks/bench_matmul.py
@@ -1,6 +1,10 @@
import torch
from triteia.python.ops import matmul_4bit_2_4, gen_sparse_quant4_NT
-from triteia.python.utils import timing_function, print_results_table, export_benchmark_results
+from triteia.python.utils import (
+    timing_function,
+    print_results_table,
+    export_benchmark_results,
+)

flops_func = lambda m, n, k: 2 * m * n * k

@@ -41,11 +45,12 @@ def w4_2_4_func(qweight, x, meta, scale):
print_results_table(f"matmul m={m},n={n},k={k}", results)
return results


if __name__ == "__main__":
results = []
results.append(benchmark(256, 32, 256))
results.append(benchmark(256, 256, 256))
results.append(benchmark(4096, 4096, 4096))
results.append(benchmark(8192, 8192, 8192))
results.append(benchmark(16384, 128, 16384))
export_benchmark_results(results, ".local/matmul_bench.json")
export_benchmark_results(results, ".local/matmul_bench.json")
14 changes: 10 additions & 4 deletions benchmarks/bench_sbmm.py
@@ -5,7 +5,11 @@
    sbmm_4bit_2_4_native,
    sbmm_4bit_2_4_multilaunch,
)
-from triteia.python.utils import timing_function, print_results_table, export_benchmark_results
+from triteia.python.utils import (
+    timing_function,
+    print_results_table,
+    export_benchmark_results,
+)
from triteia.python.configs.models.llama import llama_shapes
from triteia.python.ops.utils.generator import generate_model_distribution
from triteia.python.ops import gen_batched_sparse_quant4_NT
@@ -112,7 +116,7 @@ def w4_2_4_multilaunch_func(qweight, x, meta, scale, indices):
    [2, 4, 8, 16, 32, 64],
    [2, 4, 8, 16, 32, 64, 128],
]
-distributions = ['uniform', 'zipf:1.5']
+distributions = ["uniform", "zipf:1.5"]
ms = [4096, 8192]
ns = [4096, 8192]
for distribution in distributions:
@@ -125,6 +129,8 @@ def w4_2_4_multilaunch_func(qweight, x, meta, scale, indices):
                        benchmark(distribution, nr[i], nm[i][j], m, n)
                    )
                except Exception as e:
-                    print(f"Failed to benchmark sbmm nr={nr[i]},nm={nm[i][j]},m={m},n={n}")
+                    print(
+                        f"Failed to benchmark sbmm nr={nr[i]},nm={nm[i][j]},m={m},n={n}"
+                    )
                    print(e)
    export_benchmark_results(results, ".local/sbmm_bench.json")
52 changes: 52 additions & 0 deletions docs/examples/04_tp.py
@@ -0,0 +1,52 @@
import torch
from triteia.python.ops import gen_sparse_quant4_NT, matmul_4bit_2_4
from triteia.python.ops.utils.generator import torch_weight_to_sparse_marlin

dev = "cuda"
n = 1
m = 12288
k = 6144
groupsize = -1
tp_size = 8

x = torch.randn((n, k), dtype=torch.float16, device=dev)
weight_ref, qweight, scale, meta = gen_sparse_quant4_NT(
    m, k, groupsize=groupsize, device=dev
)
print(
    f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}"
)
fp16_output = torch.matmul(x, weight_ref)
qs_output = matmul_4bit_2_4(qweight, x, meta, scale)

qweights_by_tp, scales_by_tp, metas_by_tp = torch_weight_to_sparse_marlin(
    weight_ref, scale, tp_size=tp_size, chunk_by="column"
)
partial_outputs = []
partial_fp16_outputs = []
for i in range(tp_size):
    tp_weight = weight_ref[:, i * m // tp_size : (i + 1) * m // tp_size].contiguous()
    tp_scales = scale[:, i * m // tp_size : (i + 1) * m // tp_size].contiguous()
    partial_output = matmul_4bit_2_4(
        qweights_by_tp[i], x, metas_by_tp[i], scales_by_tp[i]
    )

    partial_outputs.append(partial_output)
    partial_fp16_output = torch.matmul(x, tp_weight)
    partial_fp16_outputs.append(partial_fp16_output)

torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
tp_output = torch.cat(partial_outputs, dim=1)
fp16_merged_output = torch.cat(partial_fp16_outputs, dim=1)

print(f"max diff (quant): {torch.max(torch.abs(fp16_output - qs_output))}")
print(
    f"max diff (tp): {torch.max(torch.abs(tp_output - fp16_output) / torch.mean(torch.abs(fp16_output)))}"
)
print(tp_output - qs_output)
print(fp16_output - fp16_merged_output)
print(
    f"mean diff (fp16): {torch.mean(torch.abs(fp16_output - fp16_merged_output) / torch.mean(torch.abs(fp16_output)))}"
)
# print(f"\n\n{tp_output}\n{qs_output}")
20 changes: 17 additions & 3 deletions setup.py
@@ -5,15 +5,22 @@
from setuptools import find_packages, setup
from torch.utils import cpp_extension


def get_compute_capability():
    try:
-        compute_cap = os.popen("nvidia-smi --query-gpu=compute_cap --format=csv,noheader").read().strip().split("\n")[0]
+        compute_cap = (
+            os.popen("nvidia-smi --query-gpu=compute_cap --format=csv,noheader")
+            .read()
+            .strip()
+            .split("\n")[0]
+        )
        major, minor = compute_cap.split(".")
        return f"{major}{minor}"
    except Exception as e:
        print(f"Failed to detect compute capability: {e}")
        return None


def read(*paths, **kwargs):
    """Read the contents of a text file safely.
    >>> read("triteia", "VERSION")
@@ -36,7 +43,8 @@ def read_requirements(path):
        for line in read(path).split("\n")
        if not line.startswith(('"', "#", "-", "git+"))
    ]



compute_cap = get_compute_capability()
if compute_cap is None:
    raise ValueError("Failed to detect compute capability")
@@ -63,7 +71,13 @@ def read_requirements(path):
            ],
            dlink=True,
            extra_compile_args={
-                "nvcc": ["-O3", f"-arch=sm_{compute_cap}", "--ptxas-options=-v", "-dc", "-lineinfo"]
+                "nvcc": [
+                    "-O3",
+                    f"-arch=sm_{compute_cap}",
+                    "--ptxas-options=-v",
+                    "-dc",
+                    "-lineinfo",
+                ]
            },
            extra_link_args=["-lcudadevrt", "-lcudart"],
        ),
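A side note on the reformatted query above: get_compute_capability turns nvidia-smi's dotted compute-capability string into the numeric suffix expected by the -arch=sm_{...} nvcc flag. A standalone illustration of the same transformation, using "8.6" as an assumed sample value rather than anything taken from this commit:

# Assumed sample of: nvidia-smi --query-gpu=compute_cap --format=csv,noheader
compute_cap = "8.6"

major, minor = compute_cap.split(".")
print(f"-arch=sm_{major}{minor}")  # -> -arch=sm_86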
1 change: 0 additions & 1 deletion tests/ops/test_matmul.py
@@ -28,7 +28,6 @@ def test_tiny(self):
        # self.run_problem(256, 16, 256, groupsize=-1)
        # self.run_problem(256, 16, 512, groupsize=-1)
        self.run_problem(5504, 5504, 5504, groupsize=-1)

-
    # def test_llama(self):
    #     bsz = 16
3 changes: 2 additions & 1 deletion tests/ops/test_sbmm.py
@@ -67,7 +67,7 @@ def run_problem(
    def test_tiny(self):
        self.run_problem("uniform", 10, 5, 256, 256)
        self.run_problem("zipf:1.5", 128, 2, 4096, 12288)

    # def test_llama(self):
    #     nrs = [16, 32, 64, 128, 256]
    #     nms = [[2,4,8,16], [2,4,8,16,32], [2,4,8,16,32,64], [2,4,8,16,32,64,128], [2,4,8,16,32,64,128,256]]
@@ -79,5 +79,6 @@ def test_tiny(self):
    #     for distribution in distributions:
    #         self.run_problem(distribution, nr, nm, layer[0], layer[1])


if __name__ == "__main__":
    unittest.main()
29 changes: 19 additions & 10 deletions tests/ops/test_sbmm_tp.py
@@ -42,25 +42,34 @@ def run_problem_column(
                qweights.append(qweight)
                scales.append(scale)
                metas.append(meta)

-                fp16_partial_output = sbmm_16bit_forloop(weight_ref, x, indices, base_weight=None)
+                fp16_partial_output = sbmm_16bit_forloop(
+                    weight_ref, x, indices, base_weight=None
+                )
                native_partial_output = sbmm_4bit_2_4_native(
                    qweight, x, meta, scale, indices, base_weight=None
                )
                ref_fp16_outputs.append(fp16_partial_output)
                outputs.append(native_partial_output)

            ref_fp16_final_outputs = torch.cat(ref_fp16_outputs, dim=1)
            final_outputs = torch.cat(outputs, dim=1)

            stacked_fp16_weights = torch.cat(ref_weights, dim=2)
            stacked_qweights = torch.cat(qweights, dim=2)
            stacked_scales = torch.cat(scales, dim=2)
            stacked_metas = torch.cat(metas, dim=1)

-            stacked_fp16_output = sbmm_16bit_forloop(stacked_fp16_weights, x, indices, base_weight=None)
+            stacked_fp16_output = sbmm_16bit_forloop(
+                stacked_fp16_weights, x, indices, base_weight=None
+            )
            stacked_native_output = sbmm_4bit_2_4_native(
-                stacked_qweights, x, stacked_metas, stacked_scales, indices, base_weight=None
+                stacked_qweights,
+                x,
+                stacked_metas,
+                stacked_scales,
+                indices,
+                base_weight=None,
            )
            self.assertLess(
                torch.mean(torch.abs(final_outputs - ref_fp16_final_outputs))
@@ -72,15 +81,15 @@ def run_problem_column(
                / torch.mean(torch.abs(ref_fp16_final_outputs)),
                0.002,
            )

        except torch.cuda.OutOfMemoryError as e:
            print(f"Out of memory, skipping nr={nr}, nm={nm}, m={m}, k={k}")
        finally:
            torch.cuda.empty_cache()

    def test_tiny(self):
        self.run_problem_column("uniform", 10, 5, 256, 256, 2)


if __name__ == "__main__":
    unittest.main()
Empty file added tests/utils/test_sparsity.py
3 changes: 3 additions & 0 deletions triteia/csrc/flash_kittens/README.md
@@ -0,0 +1,3 @@
#

> This is a fork of [ThunderKittens](https://github.com/HazyResearch/ThunderKittens)