
Commit

minor
xzyaoi committed Oct 13, 2024
1 parent 9a6c523 commit fece40b
Showing 2 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions tests/ops/test_sbmm.py
@@ -10,7 +10,6 @@
 from triteia.python.ops.utils.generator import generate_model_distribution
 from triteia.python.ops import gen_batched_sparse_quant4_NT
 
-
 class TestSBMMOp(unittest.TestCase):
     def run_problem(
         self,
@@ -65,12 +64,13 @@ def run_problem(
         torch.cuda.empty_cache()
 
     def test_tiny(self):
-        self.run_problem("uniform", 10, 5, 256, 256)
-        self.run_problem("zipf:1.5", 128, 2, 4096, 12288)
+        for i in range(10000):
+            self.run_problem("uniform", 10, 5, 256, 256)
+            self.run_problem("zipf:1.5", 128, 2, 4096, 12288)
 
     # def test_llama(self):
-    #     nrs = [16, 32, 64, 128, 256]
-    #     nms = [[2,4,8,16], [2,4,8,16,32], [2,4,8,16,32,64], [2,4,8,16,32,64,128], [2,4,8,16,32,64,128,256]]
+    #     nrs = [8, 16, 32, 64, 128, 256]
+    #     nms = [[1, 2,3, 4, 5,6,7, 8], [2,4,8,16], [2,4,8,16,32], [2,4,8,16,32,64], [2,4,8,16,32,64,128], [2,4,8,16,32,64,128,256]]
     #     distributions = ["uniform", "zipf:1.5"]
     #     for _, layers in llama_shapes.items():
     #         for layer in layers:
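(Note: the tiny shapes are now exercised 10,000 times in a loop rather than once, presumably to soak-test the sparse batched matmul kernel for intermittent failures or memory growth; run_problem already frees GPU memory with torch.cuda.empty_cache() after each pass.)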
8 changes: 4 additions & 4 deletions triteia/csrc/ops/triteia_nm_sbmm.cu
@@ -741,15 +741,15 @@ __global__ void SBMM_2_4(
   int start = starts_ptr[batch_id];
   int count = counts_ptr[batch_id];
   int weight_indices = indices_ptr[batch_id];
-  const int4 *__restrict__ A_ptr = A + start * prob_k / 8;
+  const int4 *__restrict__ A_ptr = A + start * (prob_k / 8);
   const int4 *__restrict__ B_ptr =
       B + weight_indices * (prob_k / 16) * (prob_n / 4);
   const int4 *__restrict__ meta_ptr =
       meta + weight_indices * (prob_n / 16) * (prob_k / 8);
 
-  const int4 *__restrict__ s_ptr = s + weight_indices * prob_n / 8;
-  int4 *__restrict__ C_ptr = C + start * prob_n / 8;
-  int *locks_ptr = locks + batch_id * prob_n / 8;
+  const int4 *__restrict__ s_ptr = s + weight_indices * (prob_n / 8);
+  int4 *__restrict__ C_ptr = C + start * (prob_n / 8);
+  int *locks_ptr = locks + batch_id * (prob_n / 8);
   int thread_m = -1;
   int thread_k = -1;
 
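The kernel change is pure parenthesization. In C++, start * prob_k / 8 parses as (start * prob_k) / 8, so the full product is formed first; with 32-bit int operands it can overflow on large problems. Because prob_k and prob_n are multiples of 8 here (the tensors are packed as int4), dividing first yields the same offset while keeping the intermediate value roughly 8x smaller. A minimal sketch of the hazard, with hypothetical sizes not taken from the repository:

#include <cstdio>

int main() {
  // Hypothetical sizes: large enough that start * prob_k exceeds
  // INT_MAX (2,147,483,647) when computed in 32-bit arithmetic.
  int start = 300000;  // row offset into A
  int prob_k = 12288;  // K dimension, a multiple of 8

  // Old form: the product overflows first. Signed overflow is
  // undefined behavior; in practice it wraps to a bogus offset.
  long long old_offset = start * prob_k / 8;

  // New form: divide first, so the intermediate stays well below INT_MAX.
  long long new_offset = start * (prob_k / 8);

  printf("old: %lld  new: %lld\n", old_offset, new_offset);
  return 0;
}

The two forms agree only when prob_k (or prob_n) is divisible by 8, which holds for these int4-packed layouts.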
