
Merge pull request #1 from eth-easl/feature/bmm
Feature/bmm
xzyaoi authored Jul 2, 2024
2 parents cac6160 + 2da685f commit bbc55b7
Showing 20 changed files with 1,839 additions and 728 deletions.
23 changes: 23 additions & 0 deletions docs/examples/01_mm.py
@@ -0,0 +1,23 @@
import torch
from triteia.python.ops import gen_sparse_quant4_NT, matmul_4bit_2_4

dev = "cuda"
n=1
m=256
k=512
groupsize = -1

# x (1, 512)
# weight_ref (512, 256)
# qweight (16, 512) --> (512, 256)
x = torch.randn((n, k), dtype=torch.float16, device=dev)
weight_ref, qweight, scale, meta = gen_sparse_quant4_NT(
    m, k, groupsize=groupsize, device=dev
)
# weight_ref = weight_ref.permute(0, 2, 1)
fp16_output = torch.matmul(x, weight_ref)
qs_output = matmul_4bit_2_4(qweight, x, meta, scale)
print(f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}")
print(fp16_output)
print(qs_output)
torch.cuda.synchronize()
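
The example prints both outputs for visual comparison; the test suites later in this diff instead assert a normalised mean absolute error below 0.002. A minimal sketch of the same check, reusing the tensors above (relative_error is an illustrative helper, not a triteia API):

def relative_error(approx, ref):
    # Mean absolute difference, normalised by the mean magnitude of the
    # reference -- the metric the tests below bound by 0.002.
    return (torch.mean(torch.abs(approx - ref)) / torch.mean(torch.abs(ref))).item()

print(relative_error(qs_output, fp16_output))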
27 changes: 27 additions & 0 deletions docs/examples/02_bmm.py
@@ -0,0 +1,27 @@
import torch
from triteia.python.ops import (
    bmm_4bit_2_4,
    bmm_4bit_2_4_forloop,
    gen_batched_sparse_quant4_NT
)

dev = "cuda"
b=16
n=1
m=256
p=512
groupsize = -1

x = torch.randn((b, 1, p), dtype=torch.float16, device=dev)
weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
    b, m, p, groupsize=groupsize, device=dev
)
# weight_ref = weight_ref.permute(0, 2, 1)
fp16_output = torch.bmm(x, weight_ref)
forloop_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale)
native_output = bmm_4bit_2_4(qweight, x, meta, scale)
print(fp16_output)
print(forloop_output)
print(native_output)
print(f"native_output: {native_output.shape}, fp16_output: {fp16_output.shape}, forloop_output: {forloop_output.shape}")
torch.cuda.synchronize()
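
For orientation: torch.bmm maps (b, 1, p) @ (b, p, m) to (b, 1, m), so with b=16, p=512, m=256 the reference weights are presumably (16, 512, 256) and all three outputs should be (16, 1, 256). A hedged shape check that could be appended to the example (it assumes gen_batched_sparse_quant4_NT emits (b, p, m) reference weights, which this diff does not show):

assert weight_ref.shape == (b, p, m)
for out in (fp16_output, forloop_output, native_output):
    # All three paths should agree on the torch.bmm output shape.
    assert out.shape == (b, 1, m), out.shape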
11 changes: 9 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -21,13 +21,15 @@ def read(*paths, **kwargs):
         content = open_file.read().strip()
     return content
 
+
 def read_requirements(path):
     return [
         line.strip()
         for line in read(path).split("\n")
         if not line.startswith(('"', "#", "-", "git+"))
     ]
 
+
 setup(
     name="triteia",
     version=read("triteia", "VERSION"),
@@ -41,14 +43,19 @@ def read_requirements(path):
     extras_require={"test": read_requirements("requirements-dev.txt")},
     ext_modules=[
         cpp_extension.CUDAExtension(
-            "marlin_cuda",
+            "triteia_cuda",
             [
                 "triteia/csrc/ops/ops.cpp",
                 "triteia/csrc/ops/marlin_nm.cu",
+                "triteia/csrc/ops/triteia_nm_bmm.cu",
             ],
+            dlink=True,
             extra_compile_args={
-                "nvcc": ["-O3", "-arch=sm_86", "--ptxas-options=-v", "-lineinfo"]
+                "nvcc": [
+                    "-O3", "-arch=sm_86", "--ptxas-options=-v", "-dc", "-lineinfo"
+                ]
             },
+            extra_link_args=["-lcudadevrt","-lcudart"],
         ),
     ],
     cmdclass={"build_ext": cpp_extension.BuildExtension},
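
The extension is renamed from marlin_cuda to triteia_cuda, a second kernel (triteia_nm_bmm.cu) joins the build, and -dc together with dlink=True and -lcudadevrt enables separate compilation with device linking. A smoke-test sketch after pip install -e . (it assumes an Ampere sm_86 GPU and a CUDA toolchain compatible with the flags above):

import torch
import triteia_cuda  # the import fails here if the extension did not compile or link

print(torch.version.cuda)                  # CUDA toolkit PyTorch was built against
print(torch.cuda.get_device_capability())  # expect (8, 6) for an -arch=sm_86 build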
54 changes: 54 additions & 0 deletions tests/ops/test_bmm.py
@@ -0,0 +1,54 @@
import torch
import unittest
from triteia.python.ops import (
    bmm_4bit_2_4_forloop,
    gen_batched_sparse_quant4_NT,
    bmm_4bit_2_4,
)
from triteia.python.configs.models.llama import llama_shapes

class TestMatmulOp(unittest.TestCase):
    def run_problem(self, b: int, m: int, n: int, k: int, groupsize=-1, dev="cuda"):
        try:
            print(f"Running bmm problem with b={b}, m={m}, n={n}, k={k}")
            x = torch.randn((b, 1, k), dtype=torch.float16, device=dev)
            weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
                b, m, k, groupsize=groupsize, device=dev
            )
            fp16_output = torch.matmul(x, weight_ref)
            forloop_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale)
            native_output = bmm_4bit_2_4(qweight, x, meta, scale)
            torch.cuda.synchronize()
            self.assertLess(
                torch.mean(torch.abs(forloop_output - fp16_output))
                / torch.mean(torch.abs(fp16_output)),
                0.002,
            )
            self.assertLess(
                torch.mean(torch.abs(native_output - fp16_output))
                / torch.mean(torch.abs(fp16_output)),
                0.002,
            )
        except torch.cuda.OutOfMemoryError:
            print(f"Out of memory, skipping b={b}, m={m}, n={n}, k={k}")

    def test_tiny(self):
        self.run_problem(16, 256, 16, 256, groupsize=-1)
        self.run_problem(16, 512, 16, 512, groupsize=-1)
        self.run_problem(16, 256, 16, 512, groupsize=-1)
        self.run_problem(16, 512, 16, 256, groupsize=-1)
        self.run_problem(8, 256, 16, 256, groupsize=-1)
        self.run_problem(4, 512, 16, 512, groupsize=-1)
        self.run_problem(4, 256, 16, 512, groupsize=-1)
        self.run_problem(8, 512, 16, 256, groupsize=-1)

    def test_llama(self):
        bszs = [4, 8, 16]
        for _, layers in llama_shapes.items():
            for layer in layers:
                for bsz in bszs:
                    self.run_problem(bsz, layer[1], 16, layer[0])


if __name__ == "__main__":
    unittest.main()
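
test_llama pulls llama_shapes from triteia.python.configs.models.llama, which this diff does not show. Judging from the indexing (layer[0] feeds the reduction dimension k, layer[1] the output dimension m), it is presumably a dict mapping model names to lists of (in_features, out_features) pairs, roughly along these hypothetical lines:

# Hypothetical sketch only -- the real table is not part of this diff.
llama_shapes = {
    "llama-2-7b": [(4096, 4096), (4096, 11008), (11008, 4096)],
}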
10 changes: 6 additions & 4 deletions tests/ops/test_matmul.py
@@ -1,14 +1,14 @@
 import torch
 import unittest
-from triteia.python.ops import matmul_4bit_2_4, gen_quant4_NT
+from triteia.python.ops import matmul_4bit_2_4, gen_sparse_quant4_NT
 from triteia.python.configs.models.llama import llama_shapes
 
 class TestMatmulOp(unittest.TestCase):
     def run_problem(self, m: int, n: int, k: int, groupsize=-1, dev="cuda"):
         try:
-            print(f"Running problem with m={m}, n={n}, k={k}")
+            print(f"Running mm problem with m={m}, n={n}, k={k}")
             x = torch.randn((n, k), dtype=torch.float16, device=dev)
-            weight_ref, qweight, scale, meta = gen_quant4_NT(
+            weight_ref, qweight, scale, meta = gen_sparse_quant4_NT(
                 m, k, groupsize=groupsize, device=dev
             )
             fp16_output = torch.matmul(x, weight_ref)
@@ -20,11 +20,13 @@ def run_problem(self, m: int, n: int, k: int, groupsize=-1, dev="cuda"):
                 0.002,
             )
         except torch.cuda.OutOfMemoryError as e:
-            print("Out of memory, skipping")
+            print(f"Out of memory, skipping m={m}, n={n}, k={k}")
 
     def test_tiny(self):
+        self.run_problem(21504*2, 4096, 21504*2, groupsize=-1)
         self.run_problem(256, 16, 256, groupsize=-1)
         self.run_problem(256, 16, 512, groupsize=-1)
 
+
     def test_llama(self):
         bsz = 16
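
Both suites are plain unittest modules, so standard discovery runs them; a minimal runner sketch, assuming it is executed from the repository root:

import unittest

# Collects tests/ops/test_matmul.py and tests/ops/test_bmm.py.
suite = unittest.defaultTestLoader.discover("tests/ops", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)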
