diff --git a/docs/examples/01_mm.py b/docs/examples/01_mm.py new file mode 100644 index 0000000..8637489 --- /dev/null +++ b/docs/examples/01_mm.py @@ -0,0 +1,23 @@ +import torch +from triteia.python.ops import gen_sparse_quant4_NT, matmul_4bit_2_4 + +dev = "cuda" +n=1 +m=256 +k=512 +groupsize = -1 + +# x (1, 512) +# weight_ref (512, 256) +# qweight (16, 512) --> (512, 256) +x = torch.randn((n, k), dtype=torch.float16, device=dev) +weight_ref, qweight, scale, meta = gen_sparse_quant4_NT( + m, k, groupsize=groupsize, device=dev +) +# weight_ref = weight_ref.permute(0, 2, 1) +fp16_output = torch.matmul(x, weight_ref) +qs_output = matmul_4bit_2_4(qweight, x, meta, scale) +print(f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}") +print(fp16_output) +print(qs_output) +torch.cuda.synchronize() diff --git a/docs/examples/02_bmm.py b/docs/examples/02_bmm.py index 5dc94f7..00da2ba 100644 --- a/docs/examples/02_bmm.py +++ b/docs/examples/02_bmm.py @@ -1,5 +1,9 @@ import torch -from triteia.python.ops import bmm_4bit_2_4_forloop, gen_batched_sparse_quant4_NT +from triteia.python.ops import ( + bmm_4bit_2_4, + bmm_4bit_2_4_forloop, + gen_batched_sparse_quant4_NT +) dev = "cuda" b=16 @@ -8,14 +12,16 @@ p=512 groupsize = -1 -x = torch.randn((b, 1, m), dtype=torch.float16, device=dev) +x = torch.randn((b,1, p), dtype=torch.float16, device=dev) weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT( b, m, p, groupsize=groupsize, device=dev ) # weight_ref = weight_ref.permute(0, 2, 1) fp16_output = torch.bmm(x, weight_ref) -qs_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale) -print(f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}") +forloop_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale) +native_output = bmm_4bit_2_4(qweight, x, meta, scale) print(fp16_output) -print(qs_output) +print(forloop_output) +print(native_output) +print(f"native_output: 
{native_output.shape}, fp16_output: {fp16_output.shape}, forloop_output: {forloop_output.shape}") torch.cuda.synchronize() diff --git a/setup.py b/setup.py index 622f9a4..2d0f7bc 100644 --- a/setup.py +++ b/setup.py @@ -45,9 +45,8 @@ def read_requirements(path): cpp_extension.CUDAExtension( "triteia_cuda", [ - "triteia/csrc/ops/marlin_ops.cpp", + "triteia/csrc/ops/ops.cpp", "triteia/csrc/ops/marlin_nm.cu", - "triteia/csrc/ops/triteia_ops.cpp", "triteia/csrc/ops/triteia_nm_bmm.cu", ], dlink=True, diff --git a/tests/ops/test_bmm.py b/tests/ops/test_bmm.py index 1f993fe..8c7fd48 100644 --- a/tests/ops/test_bmm.py +++ b/tests/ops/test_bmm.py @@ -1,6 +1,10 @@ import torch import unittest -from triteia.python.ops import bmm_4bit_2_4_forloop, gen_batched_sparse_quant4_NT +from triteia.python.ops import ( + bmm_4bit_2_4_forloop, + gen_batched_sparse_quant4_NT, + bmm_4bit_2_4, +) from triteia.python.configs.models.llama import llama_shapes class TestMatmulOp(unittest.TestCase): @@ -12,19 +16,38 @@ def run_problem(self, b: int, m: int, n: int, k: int, groupsize=-1, dev="cuda"): b, m, k, groupsize=groupsize, device=dev ) fp16_output = torch.matmul(x, weight_ref) - qs_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale) + forloop_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale) + native_output = bmm_4bit_2_4(qweight, x, meta, scale) torch.cuda.synchronize() - self.assertLess( - torch.mean(torch.abs(qs_output - fp16_output)) + torch.mean(torch.abs(forloop_output - fp16_output)) + / torch.mean(torch.abs(fp16_output)), + 0.002, + ) + self.assertLess( + torch.mean(torch.abs(native_output - fp16_output)) / torch.mean(torch.abs(fp16_output)), 0.002, ) except torch.cuda.OutOfMemoryError as e: - print("Out of memory, skipping") + print(f"Out of memory, skipping b={b} m={m}, n={n}, k={k}") def test_tiny(self): self.run_problem(16, 256, 16, 256, groupsize=-1) + self.run_problem(16, 512, 16, 512, groupsize=-1) + self.run_problem(16, 256, 16, 512, groupsize=-1) + 
self.run_problem(16, 512, 16, 256, groupsize=-1) + self.run_problem(8, 256, 16, 256, groupsize=-1) + self.run_problem(4, 512, 16, 512, groupsize=-1) + self.run_problem(4, 256, 16, 512, groupsize=-1) + self.run_problem(8, 512, 16, 256, groupsize=-1) + + def test_llama(self): + bszs = [4, 8, 16] + for _, layers in llama_shapes.items(): + for layer in layers: + for bsz in bszs: + self.run_problem(bsz, layer[1], 16, layer[0]) if __name__ == "__main__": diff --git a/triteia/csrc/ops/marlin_ops.cpp b/triteia/csrc/ops/marlin_ops.cpp deleted file mode 100644 index 89026cd..0000000 --- a/triteia/csrc/ops/marlin_ops.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include -#include -#include -#include - -namespace marlin { -int marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, - void *s, int prob_m, int prob_n, int prob_k, - void *workspace, int groupsize = -1, int dev = 0, - cudaStream_t stream = 0, int thread_k = -1, - int thread_m = -1, int sms = -1, int max_par = 16); -const int ERR_PROB_SHAPE = 1; -const int ERR_KERN_SHAPE = 2; -void mul_2_4(const torch::Tensor &A, const torch::Tensor &B, - const torch::Tensor &meta, torch::Tensor &C, - const torch::Tensor &s, torch::Tensor &workspace, - int thread_k = -1, int thread_m = -1, int sms = -1, - int max_par = 8) { - int prob_n = A.size(0); - int prob_m = C.size(1); - int prob_k = A.size(1); - int groupsize = (s.size(0) == 1) ? 
-1 : prob_k / 2 / s.size(0); - // printf("groupsize is:%d\n", groupsize); - if (groupsize != -1 && groupsize * s.size(0) != (prob_k / 2)) - AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); - if (workspace.numel() < prob_m / 128 * max_par) - AT_ERROR("workspace must be of size at least ", prob_m / 128 * max_par, - "."); - int dev = A.get_device(); - int err = marlin_cuda_2_4( - A.data_ptr(), B.data_ptr(), meta.data_ptr(), C.data_ptr(), s.data_ptr(), - prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); - if (err == ERR_PROB_SHAPE) { - AT_ERROR("Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", - " not compatible with thread_k=", thread_k, - ", thread_m=", thread_m, "."); - } else if (err == ERR_KERN_SHAPE) { - AT_ERROR("No kernel implementation for thread_k=", thread_k, - ", thread_m=", thread_m, ", groupsize=", groupsize, "."); - } -} -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("mul_2_4", &mul_2_4, "Marlin FP16xINT4 matmul with 2:4 sparsity."); -} -} \ No newline at end of file diff --git a/triteia/csrc/ops/ops.cpp b/triteia/csrc/ops/ops.cpp new file mode 100644 index 0000000..bb75500 --- /dev/null +++ b/triteia/csrc/ops/ops.cpp @@ -0,0 +1,98 @@ +#include +#include +#include +#include + +namespace marlin { +int marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, + void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int groupsize = -1, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16); +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; +void mul_2_4(const torch::Tensor &A, const torch::Tensor &B, + const torch::Tensor &meta, torch::Tensor &C, + const torch::Tensor &s, torch::Tensor &workspace, + int thread_k = -1, int thread_m = -1, int sms = -1, + int max_par = 8) { + int prob_n = A.size(0); + int prob_m = C.size(1); + int prob_k = 
A.size(1); + int groupsize = (s.size(0) == 1) ? -1 : prob_k / 2 / s.size(0); + // printf("groupsize is:%d\n", groupsize); + if (groupsize != -1 && groupsize * s.size(0) != (prob_k / 2)) + AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); + if (workspace.numel() < prob_m / 128 * max_par) + AT_ERROR("workspace must be of size at least ", prob_m / 128 * max_par, + "."); + int dev = A.get_device(); + int err = marlin_cuda_2_4( + A.data_ptr(), B.data_ptr(), meta.data_ptr(), C.data_ptr(), s.data_ptr(), + prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); + if (err == ERR_PROB_SHAPE) { + AT_ERROR("Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", + " not compatible with thread_k=", thread_k, + ", thread_m=", thread_m, "."); + } else if (err == ERR_KERN_SHAPE) { + AT_ERROR("No kernel implementation for thread_k=", thread_k, + ", thread_m=", thread_m, ", groupsize=", groupsize, "."); + } +} +} // namespace marlin + +namespace triteia { +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; +int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta, + void *C, void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int groupsize = -1, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16); + +void bmm_2_4(const torch::Tensor &A, const torch::Tensor &B, + const torch::Tensor &meta, torch::Tensor &C, + const torch::Tensor &s, torch::Tensor &workspace, + int thread_k = -1, int thread_m = -1, int sms = -1, + int max_par = 8) { + /** + * A: [n, k]: n: #reqs, k: in features + * B: [n, k/16, 2*m]: n: #reqs, k: in features, m: out features + * C: [n, m]: n: #reqs, m: out features + * s: [n, 1, m]: n: #reqs, m: out features + * meta: [n, k, m/16]: n: #reqs, k: in features, m: out features + */ + + int prob_n = A.size(0); + int prob_m = C.size(2); + int prob_k = A.size(2); + + int 
groupsize = (s.size(1) == 1) ? -1 : prob_k / s.size(1); + if (groupsize != -1 && groupsize * s.size(1) != prob_k) + AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); + if (workspace.numel() < prob_n * prob_m / 128 * max_par) + AT_ERROR("workspace must be of size at least ", prob_n * prob_m / 128 * max_par, + "."); + int dev = A.get_device(); + int err = triteia_cuda_bmm_2_4( + A.data_ptr(), B.data_ptr(), meta.data_ptr(), C.data_ptr(), s.data_ptr(), + prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); + if (err == ERR_PROB_SHAPE) { + AT_ERROR("Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", + " not compatible with thread_k=", thread_k, + ", thread_m=", thread_m, "."); + } else if (err == ERR_KERN_SHAPE) { + AT_ERROR("No kernel implementation for thread_k=", thread_k, + ", thread_m=", thread_m, ", groupsize=", groupsize, "."); + } +} + +} // namespace triteia + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("mul_2_4", &marlin::mul_2_4, + "Marlin FP16xINT4 matmul with 2:4 sparsity."); + m.def("bmm_2_4", &triteia::bmm_2_4, "FP16xINT4 bmm with 2:4 sparsity."); +} \ No newline at end of file diff --git a/triteia/csrc/ops/triteia_nm_bmm.cu b/triteia/csrc/ops/triteia_nm_bmm.cu index f7f5a89..9977b7e 100644 --- a/triteia/csrc/ops/triteia_nm_bmm.cu +++ b/triteia/csrc/ops/triteia_nm_bmm.cu @@ -3,12 +3,704 @@ #include #include #include + #include #include "marlin.cuh" #include "triteia.cuh" +using namespace marlin; + namespace triteia { +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ void marlin_2_4_internal( + const int4 *__restrict__ A, // fp16 input matrix of shape r x k + const int4 + *__restrict__ B, // 4bit quantized weight matrix of shape r x k x n + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + 
int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = + prob_k / 32 / thread_k_blocks; // number of thread_k_blocks in k-dim + int n_tiles = + prob_n / 16 / thread_n_blocks; // number of thread_n_blocks in n-dim + int iters = ceildiv(k_tiles * n_tiles * parallel, + gridDim.x); // iters needeed to cover all slices + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. 
+ // if (group_blocks != -1) + // iters = (group_blocks / thread_k_blocks) * + // ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + B += + (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n * prob_k / 64; + meta += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n * prob_k / + 128; + s += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + + B += 16 * thread_m_blocks * prob_n * prob_k / 64; + meta += 16 * thread_m_blocks * prob_n * prob_k / 128; + s += 16 * thread_m_blocks * prob_k / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 32 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 32 * thread_k_blocks / + 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 4 * ((threads / 32) / + (thread_n_blocks / + 4)); // between shared memory tile reads //RLC: 2 * #warps k-dim + 
constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
+ int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. 
Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependicies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4 *meta_ptr[m_sh_iters]; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. 
+ FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], + meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticable drop in performance. + if (group_blocks != -1) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + // Load meta with ldsm4 + int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. 
+ if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + +// Parallel logarithmic shared memory reduction. We make sure to avoid any +// unnecessary read or write iterations, e.g., for two warps we write only once +// by warp 1 and read only once by warp 0. 
+#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped partioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce 
final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + int col = 2 * ((threadIdx.x % 32) % 4); + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, + float c4, float c5, float c6, float c7, FragS &s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2 *tmp = (half2 *)&res; + if (group_blocks == + -1) { // for per-column quantization we finally apply the scale here + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4 *)sh)[idx] = *((int4 *)&res[0]); + }; + + if (threadIdx.x / 32 < + thread_n_blocks / 4) { // RLC: only warp 0 and 1 baseline example +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], 
frag_c[i][3][0][2], + frag_s[0][2]); + // if((col+1)(); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + matmul(pipe); + + pipe++; + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= 
m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} template - <<>>( - A_ptr, B_ptr, meta_ptr, C_ptr, s_ptr, 1, prob_n, prob_k, locks_ptr); + const int possible_thread_m_blocks = 1; + marlin_2_4_internal( + A_ptr, B_ptr, meta_ptr, C_ptr, s_ptr, 1, prob_n, prob_k, locks_ptr); } }; @@ -112,12 +798,10 @@ int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta, int cols = prob_m / thread_m; int *locks = (int *)workspace; int ret = 0; - printf("prob_m: %d, prob_n: %d, prob_k: %d\n", prob_m, prob_n, prob_k); for (int i = 0; i < tot_n_blocks; i += 4) { int thread_n_blocks = tot_n_blocks - i; prob_n = tot_n - 16 * i; int par = 1; - printf("thread_n_blocks: %d\n", thread_n_blocks); if (thread_n_blocks > 4) { // Note that parallel > 1 currently only works for inputs without any // padding @@ -150,5 +834,5 @@ int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta, } return ret; } -#endif // TRITEIA_CUDA_KERNEL_CUH_ +#endif // TRITEIA_CUDA_KERNEL_CUH_ } // namespace triteia diff --git a/triteia/csrc/ops/triteia_ops.cpp b/triteia/csrc/ops/triteia_ops.cpp deleted file mode 100644 index 20f78a8..0000000 --- a/triteia/csrc/ops/triteia_ops.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include -#include -#include -#include - -namespace triteia { -const int ERR_PROB_SHAPE = 1; -const int ERR_KERN_SHAPE = 2; -int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta, void *C, - void *s, int prob_m, int prob_n, int prob_k, - void *workspace, int groupsize = -1, int dev = 0, - cudaStream_t stream = 0, int thread_k = -1, - int thread_m = -1, int sms = -1, int max_par = 16); - - - -void bmm_2_4(const torch::Tensor &A, const torch::Tensor &B, - const torch::Tensor &meta, torch::Tensor &C, - const torch::Tensor &s, torch::Tensor &workspace, - int thread_k = -1, int thread_m = -1, int sms = -1, - int max_par = 8) { - /** - * A: [n, k]: n: #reqs, k: in features - * B: [n, k/16, 2*m]: n: #reqs, k: in 
features, m: out features - * C: [n, m]: n: #reqs, m: out features - * s: [n, 1, m]: n: #reqs, m: out features - * meta: [n, k, m/16]: n: #reqs, k: in features, m: out features - */ - - int prob_n = A.size(0); - int prob_m = C.size(1); - int prob_k = A.size(1); - - int groupsize = (s.size(1) == 1) ? -1 : prob_k / s.size(1); - if (groupsize != -1 && groupsize * s.size(1) != prob_k) - AT_ERROR("k=", prob_k, " not compatible with ", s.size(0), " groups."); - if (workspace.numel() < prob_n / 128 * max_par) - AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par, - "."); - int dev = A.get_device(); - int err = triteia_cuda_bmm_2_4( - A.data_ptr(), B.data_ptr(), meta.data_ptr(), C.data_ptr(), s.data_ptr(), - prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); - if (err == ERR_PROB_SHAPE) { - AT_ERROR("Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", - " not compatible with thread_k=", thread_k, - ", thread_m=", thread_m, "."); - } else if (err == ERR_KERN_SHAPE) { - AT_ERROR("No kernel implementation for thread_k=", thread_k, - ", thread_m=", thread_m, ", groupsize=", groupsize, "."); - } -} - -} // namespace triteia \ No newline at end of file diff --git a/triteia/python/ops/__init__.py b/triteia/python/ops/__init__.py index 5bab041..f99c3f6 100644 --- a/triteia/python/ops/__init__.py +++ b/triteia/python/ops/__init__.py @@ -1,11 +1,12 @@ from .matmul.sparse_low_precision import matmul_4bit_2_4 -from .matmul.bmm import bmm_4bit_2_4_forloop +from .matmul.bmm import bmm_4bit_2_4_forloop, bmm_4bit_2_4 from .utils.sparsity import mask_creator from .utils.generator import gen_sparse_quant4_NT, gen_batched_sparse_quant4_NT __all__ = [ "matmul_4bit_2_4", + "bmm_4bit_2_4", "bmm_4bit_2_4_forloop" "mask_creator", "gen_sparse_quant4_NT", diff --git a/triteia/python/ops/matmul/bmm.py b/triteia/python/ops/matmul/bmm.py index 0ccad71..65b5440 100644 --- 
def bmm_4bit_2_4(qweights, xs, metas, ss):
    """Batched low-precision sparse matrix multiplication with 2:4 sparsity.

    Allocates a zero-filled output tensor and an int32 workspace, then hands
    all tensors to the ``bmm_2_4`` binding.

    Parameters:
        qweights: forwarded unchanged as the second argument of ``bmm_2_4``.
        xs: input tensor; its first two dimensions, dtype and device define
            the output allocation.
        metas: forwarded unchanged as the third argument of ``bmm_2_4``.
        ss: tensor with at least three dimensions; ``ss.shape[2]`` becomes the
            last output dimension (presumably the quantization scales --
            TODO confirm against the generator).

    Returns:
        The tensor of shape ``(xs.shape[0], xs.shape[1], ss.shape[2])`` that
        was passed to ``bmm_2_4`` as its output argument.
    """
    batch = xs.shape[0]
    out_shape = (batch, xs.shape[1], ss.shape[2])
    output = torch.zeros(out_shape, dtype=xs.dtype, device=xs.device)

    # One int32 slot per batch entry times the larger of ss' trailing dims.
    workspace_len = batch * max(ss.shape[2], ss.shape[1])
    workspace = torch.zeros(workspace_len, device=xs.device, dtype=torch.int32)

    bmm_2_4(xs, qweights, metas, output, ss, workspace)
    return output