From 87fc48ed4b325b4b5f4402cc2e9930310ac2344a Mon Sep 17 00:00:00 2001 From: Xiaozhe Yao Date: Sun, 13 Oct 2024 23:41:11 +0000 Subject: [PATCH] minor bug fix --- tests/ops/test_sbmm.py | 8 +++++--- triteia/csrc/ops/triteia_nm_sbmm.cu | 5 +++-- triteia/python/ops/matmul/sbmm.py | 4 +--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/ops/test_sbmm.py b/tests/ops/test_sbmm.py index bddad63..922c98b 100644 --- a/tests/ops/test_sbmm.py +++ b/tests/ops/test_sbmm.py @@ -61,11 +61,13 @@ def run_problem( except torch.cuda.OutOfMemoryError as e: print(f"Out of memory, skipping nr={nr}, nm={nm}, m={m}, k={k}") finally: - torch.cuda.empty_cache() + pass + # torch.cuda.empty_cache() def test_tiny(self): - for i in range(10000): - self.run_problem("uniform", 10, 5, 256, 256) + for i in range(20): + self.run_problem("uniform", 10, 5, 256, 512) + # self.run_problem("uniform", 128, 2, 4096, 12288) self.run_problem("zipf:1.5", 128, 2, 4096, 12288) # def test_llama(self): diff --git a/triteia/csrc/ops/triteia_nm_sbmm.cu b/triteia/csrc/ops/triteia_nm_sbmm.cu index 3c99ace..8d6fc92 100644 --- a/triteia/csrc/ops/triteia_nm_sbmm.cu +++ b/triteia/csrc/ops/triteia_nm_sbmm.cu @@ -748,7 +748,7 @@ __global__ void SBMM_2_4( meta + weight_indices * (prob_n / 16) * (prob_k / 8); const int4 *__restrict__ s_ptr = s + weight_indices * (prob_n / 8); - int4 *__restrict__ C_ptr = C + start * (prob_n / 8); + int4 *C_ptr = C + start * (prob_n / 8); int *locks_ptr = locks + batch_id * (prob_n / 8); int thread_m = -1; int thread_k = -1; @@ -769,6 +769,7 @@ __global__ void SBMM_2_4( for (int i = 0; i < tot_n_blocks; i += 4) { int thread_n_blocks = tot_n_blocks - i; + count = tot_n - 16 * i; int par = 1; if (thread_n_blocks > 4) { par = (16 * thread_n_blocks - pad) / 64; @@ -798,10 +799,10 @@ __global__ void SBMM_2_4( } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err)); - __syncthreads(); A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par; C_ptr += 16 * thread_n_blocks * (prob_n / 8) * par; } + C_ptr = C + start * (prob_n / 8); } }; diff --git a/triteia/python/ops/matmul/sbmm.py b/triteia/python/ops/matmul/sbmm.py index fb94d8f..844bd51 100644 --- a/triteia/python/ops/matmul/sbmm.py +++ b/triteia/python/ops/matmul/sbmm.py @@ -93,10 +93,8 @@ def sbmm_4bit_2_4_native(qweights, xs, metas, ss, indices, base_weight=None): y = torch.zeros(xs.shape[0], ss.shape[2], dtype=xs.dtype, device=xs.device) if torch.all(indices == -1): return y - unique_indices, counts = torch.unique_consecutive(indices, return_counts=True) if len(unique_indices) == 1: - # use a normal matmul workspace = torch.zeros( y.shape[1] // 128 * 16, device=xs.device, dtype=torch.int32 ) @@ -127,7 +125,7 @@ def sbmm_4bit_2_4_native(qweights, xs, metas, ss, indices, base_weight=None): len(unique_indices), y.shape[1] // 8, device=xs.device, dtype=torch.int32 ) output = torch.zeros( - (xs.shape[0], y.shape[1]), dtype=torch.float16, device=xs.device + (xs.shape[0], ss.shape[2]), dtype=torch.float16, device=xs.device ) sbmm_2_4( xs,