Skip to content

Commit

Permalink
minor bug fix
Browse files — browse the repository at this point in the history
  • Loading branch information
xzyaoi committed Oct 13, 2024
1 parent fece40b commit 87fc48e
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
8 changes: 5 additions & 3 deletions tests/ops/test_sbmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ def run_problem(
except torch.cuda.OutOfMemoryError as e:
print(f"Out of memory, skipping nr={nr}, nm={nm}, m={m}, k={k}")
finally:
torch.cuda.empty_cache()
pass
# torch.cuda.empty_cache()

def test_tiny(self):
for i in range(10000):
self.run_problem("uniform", 10, 5, 256, 256)
for i in range(20):
self.run_problem("uniform", 10, 5, 256, 512)
# self.run_problem("uniform", 128, 2, 4096, 12288)
self.run_problem("zipf:1.5", 128, 2, 4096, 12288)

# def test_llama(self):
Expand Down
5 changes: 3 additions & 2 deletions triteia/csrc/ops/triteia_nm_sbmm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ __global__ void SBMM_2_4(
meta + weight_indices * (prob_n / 16) * (prob_k / 8);

const int4 *__restrict__ s_ptr = s + weight_indices * (prob_n / 8);
int4 *__restrict__ C_ptr = C + start * (prob_n / 8);
int4 *C_ptr = C + start * (prob_n / 8);
int *locks_ptr = locks + batch_id * (prob_n / 8);
int thread_m = -1;
int thread_k = -1;
Expand All @@ -769,6 +769,7 @@ __global__ void SBMM_2_4(

for (int i = 0; i < tot_n_blocks; i += 4) {
int thread_n_blocks = tot_n_blocks - i;
count = tot_n - 16 * i;
int par = 1;
if (thread_n_blocks > 4) {
par = (16 * thread_n_blocks - pad) / 64;
Expand Down Expand Up @@ -798,10 +799,10 @@ __global__ void SBMM_2_4(
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) printf("Error: %s\n", cudaGetErrorString(err));
__syncthreads();
A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_n_blocks * (prob_n / 8) * par;
}
C_ptr = C + start * (prob_n / 8);
}
};

Expand Down
4 changes: 1 addition & 3 deletions triteia/python/ops/matmul/sbmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,8 @@ def sbmm_4bit_2_4_native(qweights, xs, metas, ss, indices, base_weight=None):
y = torch.zeros(xs.shape[0], ss.shape[2], dtype=xs.dtype, device=xs.device)
if torch.all(indices == -1):
return y

unique_indices, counts = torch.unique_consecutive(indices, return_counts=True)
if len(unique_indices) == 1:
# use a normal matmul
workspace = torch.zeros(
y.shape[1] // 128 * 16, device=xs.device, dtype=torch.int32
)
Expand Down Expand Up @@ -127,7 +125,7 @@ def sbmm_4bit_2_4_native(qweights, xs, metas, ss, indices, base_weight=None):
len(unique_indices), y.shape[1] // 8, device=xs.device, dtype=torch.int32
)
output = torch.zeros(
(xs.shape[0], y.shape[1]), dtype=torch.float16, device=xs.device
(xs.shape[0], ss.shape[2]), dtype=torch.float16, device=xs.device
)
sbmm_2_4(
xs,
Expand Down

0 comments on commit 87fc48e

Please sign in to comment.