diff --git a/docs/examples/02_bmm.py b/docs/examples/02_bmm.py
new file mode 100644
index 0000000..5dc94f7
--- /dev/null
+++ b/docs/examples/02_bmm.py
@@ -0,0 +1,21 @@
+import torch
+from triteia.python.ops import bmm_4bit_2_4_forloop, gen_batched_sparse_quant4_NT
+
+dev = "cuda"
+b = 16
+n = 1
+m = 256
+p = 512
+groupsize = -1
+
+x = torch.randn((b, 1, m), dtype=torch.float16, device=dev)
+weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
+    b, m, p, groupsize=groupsize, device=dev
+)
+# weight_ref = weight_ref.permute(0, 2, 1)
+fp16_output = torch.bmm(x, weight_ref)
+qs_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale)
+print(f"weight_ref: {weight_ref.shape}, qweight: {qweight.shape}, scale: {scale.shape}, meta: {meta.shape}")
+print(fp16_output)
+print(qs_output)
+torch.cuda.synchronize()
diff --git a/setup.py b/setup.py
index 1c2ed86..622f9a4 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@ def read(*paths, **kwargs):
         content = open_file.read().strip()
     return content
 
+
 def read_requirements(path):
     return [
         line.strip()
@@ -28,6 +29,7 @@ def read_requirements(path):
         if not line.startswith(('"', "#", "-", "git+"))
     ]
 
+
 setup(
     name="triteia",
     version=read("triteia", "VERSION"),
@@ -41,14 +43,20 @@ def read_requirements(path):
     extras_require={"test": read_requirements("requirements-dev.txt")},
     ext_modules=[
         cpp_extension.CUDAExtension(
-            "marlin_cuda",
+            "triteia_cuda",
             [
-                "triteia/csrc/ops/ops.cpp",
+                "triteia/csrc/ops/marlin_ops.cpp",
                 "triteia/csrc/ops/marlin_nm.cu",
+                "triteia/csrc/ops/triteia_ops.cpp",
+                "triteia/csrc/ops/triteia_nm_bmm.cu",
             ],
+            dlink=True,
             extra_compile_args={
-                "nvcc": ["-O3", "-arch=sm_86", "--ptxas-options=-v", "-lineinfo"]
+                "nvcc": [
+                    "-O3", "-arch=sm_86", "--ptxas-options=-v", "-dc", "-lineinfo"
+                ]
             },
+            extra_link_args=["-lcudadevrt", "-lcudart"],
         ),
     ],
     cmdclass={"build_ext": cpp_extension.BuildExtension},
diff --git a/tests/ops/test_bmm.py b/tests/ops/test_bmm.py
new file mode 100644
index 0000000..1f993fe
--- /dev/null
+++ b/tests/ops/test_bmm.py
@@ -0,0 +1,31 @@
+import torch
+import unittest
+from triteia.python.ops import bmm_4bit_2_4_forloop, gen_batched_sparse_quant4_NT
+from triteia.python.configs.models.llama import llama_shapes
+
+class TestMatmulOp(unittest.TestCase):
+    def run_problem(self, b: int, m: int, n: int, k: int, groupsize=-1, dev="cuda"):
+        try:
+            print(f"Running bmm problem with b={b}, m={m}, n={n}, k={k}")
+            x = torch.randn((b, 1, k), dtype=torch.float16, device=dev)
+            weight_ref, qweight, scale, meta = gen_batched_sparse_quant4_NT(
+                b, m, k, groupsize=groupsize, device=dev
+            )
+            fp16_output = torch.matmul(x, weight_ref)
+            qs_output = bmm_4bit_2_4_forloop(qweight, x, meta, scale)
+            torch.cuda.synchronize()
+
+            self.assertLess(
+                torch.mean(torch.abs(qs_output - fp16_output))
+                / torch.mean(torch.abs(fp16_output)),
+                0.002,
+            )
+        except torch.cuda.OutOfMemoryError as e:
+            print("Out of memory, skipping")
+
+    def test_tiny(self):
+        self.run_problem(16, 256, 16, 256, groupsize=-1)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/ops/test_matmul.py b/tests/ops/test_matmul.py
index 1ec74fe..28236a3 100644
--- a/tests/ops/test_matmul.py
+++ b/tests/ops/test_matmul.py
@@ -1,14 +1,14 @@
 import torch
 import unittest
-from triteia.python.ops import matmul_4bit_2_4, gen_quant4_NT
+from triteia.python.ops import matmul_4bit_2_4, gen_sparse_quant4_NT
 from triteia.python.configs.models.llama import llama_shapes
 
 class TestMatmulOp(unittest.TestCase):
     def run_problem(self, m: int, n: int, k: int, groupsize=-1,
dev="cuda"): try: - print(f"Running problem with m={m}, n={n}, k={k}") + print(f"Running mm problem with m={m}, n={n}, k={k}") x = torch.randn((n, k), dtype=torch.float16, device=dev) - weight_ref, qweight, scale, meta = gen_quant4_NT( + weight_ref, qweight, scale, meta = gen_sparse_quant4_NT( m, k, groupsize=groupsize, device=dev ) fp16_output = torch.matmul(x, weight_ref) @@ -20,11 +20,13 @@ def run_problem(self, m: int, n: int, k: int, groupsize=-1, dev="cuda"): 0.002, ) except torch.cuda.OutOfMemoryError as e: - print("Out of memory, skipping") + print(f"Out of memory, skipping m={m}, n={n}, k={k}") def test_tiny(self): self.run_problem(21504*2, 4096, 21504*2, groupsize=-1) self.run_problem(256, 16, 256, groupsize=-1) + self.run_problem(256, 16, 512, groupsize=-1) + def test_llama(self): bsz = 16 diff --git a/triteia/csrc/ops/marlin.cuh b/triteia/csrc/ops/marlin.cuh new file mode 100644 index 0000000..3eadb5d --- /dev/null +++ b/triteia/csrc/ops/marlin.cuh @@ -0,0 +1,696 @@ +#pragma once + +#include +#include +#include + +#include + +#include "common/base.h" +#include "common/mem.h" +#include "common/mma.h" + +namespace marlin { + +template shared + // fetch pipeline + const int group_blocks // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ inline void Marlin_2_4( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + const int4 + *__restrict__ meta, // 2bit metadata information about 2:4 format on B + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = + prob_k / 32 / thread_k_blocks; // number of thread_k_blocks in k-dim + int n_tiles = + prob_n / 16 / thread_n_blocks; // number of thread_n_blocks in n-dim + int iters = ceildiv(k_tiles * n_tiles * parallel, + gridDim.x); // iters needeed to cover all slices + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. 
+ if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 32 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 32 * thread_k_blocks / + 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 4 * ((threads / 32) / + (thread_n_blocks / + 4)); // between shared memory tile reads //RLC: 2 * #warps k-dim + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + 
int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 + constexpr int m_sh_stride = + (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp + int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; + int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); + constexpr int m_sh_wr_delta = threads / 2; + constexpr int m_sh_rd_delta = threads / 2; + constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; + constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + + (threadIdx.x % (m_sh_stride)); + m_gl_rd += (m_sh_stride)*slice_col; + m_gl_rd += m_gl_rd_delta_o * slice_row; + int m_sh_wr = threadIdx.x; + int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. 
+ int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[0][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[1][i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependicies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; + const int4 *meta_ptr[m_sh_iters]; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_m = sh_s + (stages * s_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks][2]; + I4 frag_b_quant[2]; + FragM frag_m[2][2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) { + if (m_sh_wr_pred) + cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], + meta_ptr[i]); + meta_ptr[i] += m_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticable drop in performance. + if (group_blocks != -1) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + ldsm4(frag_a[k % 2][i][0], + &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); + ldsm4(frag_a[k % 2][i][1], + &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); + } + + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + + // Load meta with ldsm4 + int4 *sh_m_stage = sh_m + m_sh_stage * pipe; + ldsm4_m(frag_m[k % 2][0], + &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], + frag_m[k % 2][j / 2], j % 2); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + +// Parallel logarithmic shared memory reduction. We make sure to avoid any +// unnecessary read or write iterations, e.g., for two warps we write only once +// by warp 1 and read only once by warp 0. 
+#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped partioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; + int c_gl_wr_delta_i = + c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) + int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int col = 2 * ((threadIdx.x % 32) % 4); + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. 
+#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + col + (i % 2) < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2] += + __half2float( + reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); + } + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j2 = 0; j2 < 2; j2++) { +#pragma unroll + for (int j1 = 0; j1 < 4; j1++) { + reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + + 4 * ((i % 4) / 2) + i % 2]); + } + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + + constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: + constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: + constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: + + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + + int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + + ((threadIdx.x % 32) / 4); // RLC: + c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) + + constexpr int c_sh_rd_delta = + c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: + int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + + (threadIdx.x % (2 * 2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + int col = 2 * ((threadIdx.x % 32) % 4); + + auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, + float c4, float c5, float c6, float c7, FragS &s1) { + uint2 res[2]; + res[0] = to_half4(c0, c1, c2, c3); + res[1] = to_half4(c4, c5, c6, c7); + half2 *tmp = (half2 *)&res; + if (group_blocks == + -1) { // for per-column quantization we finally apply the scale here + tmp[0] = __hmul2(tmp[0], s0[0]); + tmp[1] = __hmul2(tmp[1], s0[1]); + tmp[2] = __hmul2(tmp[2], s1[0]); + tmp[3] = __hmul2(tmp[3], s1[1]); + } + ((int4 *)sh)[idx] = *((int4 *)&res[0]); + }; + + if (threadIdx.x / 32 < + thread_n_blocks / 4) { // RLC: only warp 0 and 1 baseline example +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + int wr = c_sh_wr; + write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], + frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], + frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], + frag_s[0][2]); + // if((col+1)(); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + wait_for_stage(); + + fetch_to_registers(pipe + 1, (pipe + 1) % stages); + matmul(pipe); + + pipe++; + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * 
stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) + meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < m_sh_iters; i++) meta_ptr[i] -= m_gl_stride; + } + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} +} // namespace marlin \ No newline at end of file diff --git a/triteia/csrc/ops/marlin_nm.cu b/triteia/csrc/ops/marlin_nm.cu index 6430bc8..b9a5cde 100644 --- a/triteia/csrc/ops/marlin_nm.cu +++ b/triteia/csrc/ops/marlin_nm.cu @@ -1,711 +1,17 @@ - -#ifndef MARLIN_CUDA_KERNEL_CUH -#define MARLIN_CUDA_KERNEL_CUH +#ifndef MARLIN_NM_CUDA_KERNEL_CUH_ +#define MARLIN_NM_CUDA_KERNEL_CUH_ #include #include #include #include +#include "marlin.cuh" #include "common/base.h" #include "common/mem.h" #include "common/mma.h" namespace marlin { - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with - // a separate quantization scale - > -__global__ void Marlin_2_4( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - const int4 - *__restrict__ meta, // 2bit metadata information about 2:4 format on B - int4 *__restrict__ C, // fp16 output buffer of shape mxn - const int4 - *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). 
Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = - prob_k / 32 / thread_k_blocks; // number of thread_k_blocks in k-dim - int n_tiles = - prob_n / 16 / thread_n_blocks; // number of thread_n_blocks in n-dim - int iters = ceildiv(k_tiles * n_tiles * parallel, - gridDim.x); // iters needeed to cover all slices - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts in - // the middle of group. - if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. 
- auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // RLC: 8 is vec_size -> 128-bit instructions, 8 fp16 elements - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time - // constant - constexpr int a_sh_stride = - 32 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = - 32 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = - a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = - a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = - 4 * ((threads / 32) / - (thread_n_blocks / - 4)); // between shared memory tile reads //RLC: 2 * #warps k-dim - constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = - ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile - - int b_gl_stride = 16 * prob_n / 32; - constexpr int b_sh_stride = 32 * thread_n_blocks / 4; - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); - constexpr int b_sh_wr_delta = threads; - constexpr int b_sh_rd_delta = threads; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - int m_gl_stride = 2 * prob_n / 8; // (16*2*4 / 8) = 16 - constexpr int m_sh_stride = - (16 * thread_n_blocks) / 4; // #warps n-dim * threads/warp - int m_gl_rd_delta_o = m_gl_stride * thread_k_blocks; - int m_gl_rd_delta_i = m_gl_stride * (threads / m_sh_stride); - constexpr int m_sh_wr_delta = threads / 2; - constexpr int m_sh_rd_delta = threads / 2; - constexpr int m_sh_stage = m_sh_stride * thread_k_blocks; - constexpr int m_sh_iters = ceildiv(m_sh_stage, m_sh_wr_delta); - - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_sh_stage = s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. 
- int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 4 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = - b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x; - int b_sh_rd = threadIdx.x; - - int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) + - (threadIdx.x % (m_sh_stride)); - m_gl_rd += (m_sh_stride)*slice_col; - m_gl_rd += m_gl_rd_delta_o * slice_row; - int m_sh_wr = threadIdx.x; - int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16; - - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; -#pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; -#pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[2][b_sh_wr_iters][thread_m_blocks]; -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { -#pragma unroll - for (int j = 0; j < thread_m_blocks; j++) { - a_sh_rd_trans[0][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - a_sh_rd_trans[1][i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd + 2); - } - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependicies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. 
- const int4 *B_ptr[b_sh_wr_iters]; -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - bool m_sh_wr_pred = threadIdx.x < m_sh_wr_delta; - const int4 *meta_ptr[m_sh_iters]; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] = meta + m_gl_rd_delta_i * i + m_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4 *sh_a = sh; - int4 *sh_b = sh_a + (stages * a_sh_stage); - int4 *sh_s = sh_b + (stages * b_sh_stage); - int4 *sh_m = sh_s + (stages * s_sh_stage); - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks][2]; - I4 frag_b_quant[2]; - FragM frag_m[2][2]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { -#pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); - B_ptr[i] += b_gl_rd_delta_o; - } - int4 *sh_meta_stage = sh_m + m_sh_stage * pipe; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) { - if (m_sh_wr_pred) - cp_async4_stream(&sh_meta_stage[m_sh_wr_delta * i + m_sh_wr], - meta_ptr[i]); - meta_ptr[i] += m_gl_rd_delta_o; - } - // Only fetch scales if this tile starts a new group - if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) - cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); - s_gl_rd += s_gl_rd_delta; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; - // however, this does not seem to be a significant bottleneck, while some - // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticable drop in performance. 
- if (group_blocks != -1) { - int4 *sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; -#pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - ldsm4(frag_a[k % 2][i][0], - &sh_a_stage[a_sh_rd_trans[0][k % b_sh_wr_iters][i]]); - ldsm4(frag_a[k % 2][i][1], - &sh_a_stage[a_sh_rd_trans[1][k % b_sh_wr_iters][i]]); - } - - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); - - // Load meta with ldsm4 - int4 *sh_m_stage = sh_m + m_sh_stage * pipe; - ldsm4_m(frag_m[k % 2][0], - &sh_m_stage[m_sh_rd_delta * (k % m_sh_iters) + m_sh_rd]); - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { -// We have the m dimension as the inner loop in order to encourage overlapping -// dequantization and matmul operations. -#pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; - FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can - // avoid doing so for each weight. - if (group_blocks != -1) - scale(frag_b0, frag_s[k % 2][j], 0); - FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) - scale(frag_b1, frag_s[k % 2][j], 1); -#pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma_sp(frag_b0, frag_b1, frag_a[k % 2][i][0], frag_c[i][j][0], - frag_m[k % 2][j / 2], j % 2); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); - -// Parallel logarithmic shared memory reduction. We make sure to avoid any -// unnecessary read or write iterations, e.g., for two warps we write only once -// by warp 1 and read only once by warp 0. 
-#pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { -#pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { -#pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float *c_rd = reinterpret_cast( - &sh[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh[red_sh_wr]); -#pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { -#pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float *c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); -#pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partioning - // minimizes the number of such reductions and our outputs are usually rather - // small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 2 * 4 * c_gl_stride; - int c_gl_wr_delta_i = - c_gl_stride; // 8 threads (e.g., 0,4,8,12,16,20,24,28) - int c_gl_wr = 2 * c_gl_stride * (threadIdx.x % 4) + - 8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int col = 2 * ((threadIdx.x % 32) % 4); - - if (!first) { -// Interestingly, doing direct global accesses here really seems to mess up the -// compiler and lead to slowdowns, hence we also use async-copies even though -// these fetches are not actually asynchronous. 
-#pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - -#pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || - 8 * (i / 2) + col + (i % 2) < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; -#pragma unroll - for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2] += - __half2float( - reinterpret_cast<__half *>(&c_red)[(j2 * 4 + j1)]); - } - } - } - if (!last) { - int4 c; -#pragma unroll - for (int j2 = 0; j2 < 2; j2++) { -#pragma unroll - for (int j1 = 0; j1 < 4; j1++) { - reinterpret_cast<__half *>(&c)[(j2 * 4 + j1)] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 8 * j1 + 2 * j2 + - 4 * ((i % 4) / 2) + i % 2]); - } - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - - constexpr int c_sh_stride = 2 * thread_n_blocks; // RLC: - constexpr int c_sh_stride_2 = 2 * c_sh_stride + 2; // RLC: - constexpr int c_sh_stride_3 = 2 * (2 * thread_n_blocks) + 2; // RLC: - - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - - int c_sh_wr = c_sh_stride_2 * ((threadIdx.x % 32) % 4) + - ((threadIdx.x % 32) / 4); // RLC: - c_sh_wr += 8 * (threadIdx.x / 32); // 128/4(half4) - - constexpr int c_sh_rd_delta = - c_sh_stride_3 * (threads / (2 * 2 * thread_n_blocks)); // RLC: - int c_sh_rd = c_sh_stride_3 * (threadIdx.x / (2 * 2 * thread_n_blocks)) + - (threadIdx.x % (2 * 2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - int col = 2 * ((threadIdx.x % 32) % 4); - - auto write = [&](int idx, float c0, float c1, float c2, float c3, FragS &s0, - float c4, float c5, float c6, float c7, FragS &s1) { - uint2 res[2]; - res[0] = to_half4(c0, c1, c2, c3); - res[1] = to_half4(c4, c5, c6, c7); - half2 *tmp = (half2 *)&res; - if (group_blocks == - -1) { // for per-column quantization we finally apply the scale here - tmp[0] = __hmul2(tmp[0], s0[0]); - tmp[1] = __hmul2(tmp[1], s0[1]); - tmp[2] = __hmul2(tmp[2], s1[0]); - tmp[3] = __hmul2(tmp[3], s1[1]); - } - ((int4 *)sh)[idx] = *((int4 *)&res[0]); - }; - - if (threadIdx.x / 32 < - thread_n_blocks / 4) { // RLC: only warp 0 and 1 baseline example -#pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - int wr = c_sh_wr; - write(wr, frag_c[i][0][0][0], frag_c[i][1][0][0], frag_c[i][2][0][0], - frag_c[i][3][0][0], frag_s[0][0], frag_c[i][0][0][2], - frag_c[i][1][0][2], frag_c[i][2][0][2], frag_c[i][3][0][2], - frag_s[0][2]); - // if((col+1)(); - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - wait_for_stage(); - - fetch_to_registers(pipe + 1, (pipe + 1) % stages); - matmul(pipe); - - pipe++; - slice_iters--; - if (slice_iters == 0) - break; - } - a_gl_rd += a_gl_rd_delta_o 
* stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compliation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (group_blocks == -1 && last) { - if (s_sh_wr_pred) - cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); - cp_async_fence(); - } - thread_block_reduce(); - if (group_blocks == -1 && last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - *(float4 *)(frag_s) = *(float4 *)(&sh_s[s_sh_rd]); - } - } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] += (m_sh_stride)-m_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { -#pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; -#pragma unroll - for (int i = 0; i < m_sh_iters; i++) - meta_ptr[i] -= m_gl_stride; - } - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - start_pipes(); - } - } - } -} - // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. 
@@ -812,6 +118,5 @@ int marlin_cuda_2_4(const void *A, const void *B, const void *meta, void *C, return ret; } - -#endif +#endif } \ No newline at end of file diff --git a/triteia/csrc/ops/ops.cpp b/triteia/csrc/ops/marlin_ops.cpp similarity index 100% rename from triteia/csrc/ops/ops.cpp rename to triteia/csrc/ops/marlin_ops.cpp diff --git a/triteia/csrc/ops/triteia.cuh b/triteia/csrc/ops/triteia.cuh new file mode 100644 index 0000000..8a42bfb --- /dev/null +++ b/triteia/csrc/ops/triteia.cuh @@ -0,0 +1,39 @@ +#pragma once +#include +#include +#include +#include +#include "marlin.cuh" + +namespace triteia { +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; +const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define gpuErrchk(ans) \ + { gpuAssert((ans), __FILE__, __LINE__); } + +inline void gpuAssert(cudaError_t code, const char *file, int line, + bool abort = true) { + if (code != cudaSuccess) { + fprintf(stderr, "Device Side Error: %s %s %d\n", cudaGetErrorString(code), file, + line); + if (abort) exit(code); + } +} +#define CALL_MM_2_4(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + marlin::Marlin_2_4 \ + <<>>(A_ptr, B_ptr, meta_ptr, \ + C_ptr, s_ptr, count, prob_n, \ + prob_k, locks_ptr); \ + } +} // namespace triteia \ No newline at end of file diff --git a/triteia/csrc/ops/triteia_nm_bmm.cu b/triteia/csrc/ops/triteia_nm_bmm.cu new file mode 100644 index 0000000..f7f5a89 --- /dev/null +++ b/triteia/csrc/ops/triteia_nm_bmm.cu @@ -0,0 +1,154 @@ +#ifndef TRITEIA_CUDA_KERNEL_CUH_ +#define TRITEIA_CUDA_KERNEL_CUH_ +#include +#include +#include +#include + +#include "marlin.cuh" +#include "triteia.cuh" + +namespace triteia { + +template +__global__ void BMM_2_4( + /** + * A: [n, k]: n: #reqs, k: in features + * B: [n, k/32, 2*m]: n: #reqs, k: in features, m: out features + * C: [n, m]: n: #reqs, m: out features + * s: [n, 1, m]: n: #reqs, m: out features + * meta: [n, k, m/16]: n: #reqs, k: in features, m: out features + */ + const int4 *__restrict__ A, const int4 *__restrict__ B, + const int4 *__restrict__ meta, int4 *__restrict__ C, + const int4 *__restrict__ s, cudaStream_t stream, int blocks, int prob_m, + int prob_n, int prob_k, int *locks, int max_par) { + // printf("prob_m: %d, prob_n: %d, prob_k: %d\n", prob_m, prob_n, prob_k); + // 1 int4 pointer = 4 x 32 bit + // B: 32 bit packed, 4 + + // A: 16 bit, 8 + // C: 16 bit, 8 + // s: 16 bit, 8 + // meta: 16 bit, 8 + // locks: 32 bit + printf( + "thread_m_blocks: %d, thread_n_blocks: %d, thread_k_blocks: %d, " + "group_blocks: %d\n", + thread_m_blocks, thread_n_blocks, thread_k_blocks, group_blocks); + for (int batch_idx = 0; batch_idx < prob_m; batch_idx++) { + const int4 *__restrict__ A_ptr = A + batch_idx * prob_k / 8; + const int4 *__restrict__ B_ptr = B + batch_idx * prob_k * prob_n / 16 / 4; + const int4 *__restrict__ meta_ptr = + meta + batch_idx * prob_k * prob_n / 16 / 8; + + const int4 *__restrict__ s_ptr = s + batch_idx * prob_n / 8; + int4 *__restrict__ C_ptr = C + batch_idx * prob_n / 8; + int *locks_ptr = locks + batch_idx * prob_k; + const int possible_thread_m_blocks = 1; + + marlin::Marlin_2_4 + <<>>( + A_ptr, B_ptr, meta_ptr, C_ptr, s_ptr, 1, prob_n, prob_k, locks_ptr); + } +}; + 
+#define CALL_IF_BMM_2_4(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS) { \ + cudaFuncSetAttribute(BMM_2_4, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM); \ + BMM_2_4<<>>( \ + A_ptr, B_ptr, meta_ptr, C_ptr, s_ptr, stream, blocks, prob_n, prob_m, \ + prob_k, locks, max_par); \ + } + +#define Set_Max_SharedMemory(THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS) \ + cudaFuncSetAttribute( \ + marlin::Marlin_2_4, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); + +int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta, + void *C, void *s, int prob_m, int prob_n, int prob_k, + void *workspace, int groupsize = -1, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, + int thread_m = -1, int sms = -1, int max_par = 16) { + int tot_n = prob_n; + int tot_n_blocks = marlin::ceildiv(tot_n, 16); + int pad = 16 * tot_n_blocks - tot_n; + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_m == -1) { + thread_m = 128; + thread_k = 128; + } + int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction + int thread_m_blocks = thread_m / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_m % thread_m != 0 || prob_k % thread_k != 0 || + (group_blocks != -1 && (prob_k / 2) % group_blocks != 0)) + + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) return 0; + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + const int4 *meta_ptr = (const int4 *)meta; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + + int cols = prob_m / thread_m; + int *locks = (int *)workspace; + int ret = 0; + printf("prob_m: %d, prob_n: %d, prob_k: %d\n", prob_m, prob_n, prob_k); + for (int i = 0; i < tot_n_blocks; i += 4) { + int thread_n_blocks = tot_n_blocks - i; + prob_n = tot_n - 16 * i; + int par = 1; + printf("thread_n_blocks: %d\n", thread_n_blocks); + if (thread_n_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_n_blocks - pad) / 64; + if (par > max_par) par = max_par; + prob_n = 64 * par; + i += 4 * (par - 1); + thread_n_blocks = 4; + } + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + if (false) { + } // BMxBNxBK, group + CALL_IF_BMM_2_4(8, 1, 4, -1) // e.g., 16x128x128 + CALL_IF_BMM_2_4(8, 2, 4, -1) + CALL_IF_BMM_2_4(8, 4, 4, -1) // e.g., 16x128x128 + CALL_IF_BMM_2_4(16, 1, 2, -1) // e.g., 16x256x64 + CALL_IF_BMM_2_4(16, 2, 2, -1) // e.g.. 32x256x64 + CALL_IF_BMM_2_4(16, 3, 2, -1) + CALL_IF_BMM_2_4(16, 4, 2, -1) + CALL_IF_BMM_2_4(32, 1, 1, -1) // e.g., 16x256x64 + CALL_IF_BMM_2_4(32, 2, 1, -1) // e.g.. 
+int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta,
+                         void *C, void *s, int prob_m, int prob_n, int prob_k,
+                         void *workspace, int groupsize = -1, int dev = 0,
+                         cudaStream_t stream = 0, int thread_k = -1,
+                         int thread_m = -1, int sms = -1, int max_par = 16) {
+  int tot_n = prob_n;
+  int tot_n_blocks = marlin::ceildiv(tot_n, 16);
+  int pad = 16 * tot_n_blocks - tot_n;
+  if (sms == -1)
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
+  if (thread_k == -1 || thread_m == -1) {
+    thread_m = 128;
+    thread_k = 128;
+  }
+  int thread_k_blocks = thread_k / 32;  // 2:4 version with m16n8k32 instruction
+  int thread_m_blocks = thread_m / 16;
+  int group_blocks = (groupsize == -1) ? -1 : groupsize / 16;
+  int blocks = sms;
+
+  if (prob_m % thread_m != 0 || prob_k % thread_k != 0 ||
+      (group_blocks != -1 && (prob_k / 2) % group_blocks != 0))
+    return ERR_PROB_SHAPE;
+  if (prob_m == 0 || prob_n == 0 || prob_k == 0) return 0;
+  const int4 *A_ptr = (const int4 *)A;
+  const int4 *B_ptr = (const int4 *)B;
+  const int4 *meta_ptr = (const int4 *)meta;
+  int4 *C_ptr = (int4 *)C;
+  const int4 *s_ptr = (const int4 *)s;
+
+  int cols = prob_m / thread_m;
+  int *locks = (int *)workspace;
+  int ret = 0;
+  printf("prob_m: %d, prob_n: %d, prob_k: %d\n", prob_m, prob_n, prob_k);
+  for (int i = 0; i < tot_n_blocks; i += 4) {
+    int thread_n_blocks = tot_n_blocks - i;
+    prob_n = tot_n - 16 * i;
+    int par = 1;
+    printf("thread_n_blocks: %d\n", thread_n_blocks);
+    if (thread_n_blocks > 4) {
+      // Note that parallel > 1 currently only works for inputs without any
+      // padding
+      par = (16 * thread_n_blocks - pad) / 64;
+      if (par > max_par) par = max_par;
+      prob_n = 64 * par;
+      i += 4 * (par - 1);
+      thread_n_blocks = 4;
+    }
+    // For compilation speed, we only define the kernel configurations that
+    // have seemed useful (in terms of performance) in our testing, however
+    // many more are, in principle, possible.
+    if (false) {
+    }  // BMxBNxBK, group
+    CALL_IF_BMM_2_4(8, 1, 4, -1)   // e.g., 16x128x128
+    CALL_IF_BMM_2_4(8, 2, 4, -1)
+    CALL_IF_BMM_2_4(8, 4, 4, -1)   // e.g., 16x128x128
+    CALL_IF_BMM_2_4(16, 1, 2, -1)  // e.g., 16x256x64
+    CALL_IF_BMM_2_4(16, 2, 2, -1)  // e.g., 32x256x64
+    CALL_IF_BMM_2_4(16, 3, 2, -1)
+    CALL_IF_BMM_2_4(16, 4, 2, -1)
+    CALL_IF_BMM_2_4(32, 1, 1, -1)  // e.g., 16x256x64
+    CALL_IF_BMM_2_4(32, 2, 1, -1)  // e.g., 32x256x64
+    CALL_IF_BMM_2_4(32, 3, 1, -1)
+    CALL_IF_BMM_2_4(32, 4, 1, -1)
+    else ret = ERR_KERN_SHAPE;
+
+    A_ptr += 16 * thread_n_blocks * (prob_k / 8) * par;
+    C_ptr += 16 * thread_n_blocks * (prob_m / 8) * par;
+  }
+  return ret;
+}
+#endif  // TRITEIA_CUDA_KERNEL_CUH_
+}  // namespace triteia
diff --git a/triteia/csrc/ops/triteia_ops.cpp b/triteia/csrc/ops/triteia_ops.cpp
new file mode 100644
index 0000000..20f78a8
--- /dev/null
+++ b/triteia/csrc/ops/triteia_ops.cpp
@@ -0,0 +1,55 @@
+#include
+#include
+#include
+#include
+
+namespace triteia {
+const int ERR_PROB_SHAPE = 1;
+const int ERR_KERN_SHAPE = 2;
+int triteia_cuda_bmm_2_4(const void *A, const void *B, const void *meta,
+                         void *C, void *s, int prob_m, int prob_n, int prob_k,
+                         void *workspace, int groupsize = -1, int dev = 0,
+                         cudaStream_t stream = 0, int thread_k = -1,
+                         int thread_m = -1, int sms = -1, int max_par = 16);
+
+void bmm_2_4(const torch::Tensor &A, const torch::Tensor &B,
+             const torch::Tensor &meta, torch::Tensor &C,
+             const torch::Tensor &s, torch::Tensor &workspace,
+             int thread_k = -1, int thread_m = -1, int sms = -1,
+             int max_par = 8) {
+  /**
+   * A: [n, k]: n: #reqs, k: in features
+   * B: [n, k/16, 2*m]: n: #reqs, k: in features, m: out features
+   * C: [n, m]: n: #reqs, m: out features
+   * s: [n, 1, m]: n: #reqs, m: out features
+   * meta: [n, k, m/16]: n: #reqs, k: in features, m: out features
+   */
+  int prob_n = A.size(0);
+  int prob_m = C.size(1);
+  int prob_k = A.size(1);
+
+  int groupsize = (s.size(1) == 1) ? -1 : prob_k / s.size(1);
+  if (groupsize != -1 && groupsize * s.size(1) != prob_k)
+    AT_ERROR("k=", prob_k, " not compatible with ", s.size(1), " groups.");
+  if (workspace.numel() < prob_n / 128 * max_par)
+    AT_ERROR("workspace must be of size at least ", prob_n / 128 * max_par,
+             ".");
+  int dev = A.get_device();
+  int err = triteia_cuda_bmm_2_4(
+      A.data_ptr(), B.data_ptr(), meta.data_ptr(), C.data_ptr(), s.data_ptr(),
+      prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, dev,
+      at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par);
+  if (err == ERR_PROB_SHAPE) {
+    AT_ERROR("Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")",
+             " not compatible with thread_k=", thread_k,
+             ", thread_m=", thread_m, ".");
+  } else if (err == ERR_KERN_SHAPE) {
+    AT_ERROR("No kernel implementation for thread_k=", thread_k,
+             ", thread_m=", thread_m, ", groupsize=", groupsize, ".");
+  }
+}
+
+}  // namespace triteia
\ No newline at end of file
diff --git a/triteia/python/capi/__init__.py b/triteia/python/capi/__init__.py
index 1f80055..8556f0e 100644
--- a/triteia/python/capi/__init__.py
+++ b/triteia/python/capi/__init__.py
@@ -1,3 +1,3 @@
-from .marlin import marlin_mul_2_4
+from .marlin import mul_2_4, bmm_2_4
 
-__all__ = ["marlin_mul_2_4"]
+__all__ = ["mul_2_4", "bmm_2_4"]
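The C++ entry point above writes into a caller-provided C and rejects workspaces smaller than n / 128 * max_par entries. A hedged sketch of the buffer allocation a caller might do before invoking the Python binding that follows; the driver function and shapes are illustrative, derived from the comments in triteia_ops.cpp, and are not part of the shipped API.

    import torch
    from triteia.python.capi import bmm_2_4

    def run_bmm_2_4(A, B, meta, s, max_par=8):
        # A: (n, k) fp16 activations, one row per request; s: (n, 1, m) fp16 scales.
        n, _ = A.shape
        m = s.shape[-1]
        C = torch.zeros((n, m), dtype=A.dtype, device=A.device)  # written in place
        # Lock/barrier scratch; the wrapper checks numel() >= n / 128 * max_par.
        workspace = torch.zeros(
            max(n // 128 * max_par, 1), dtype=torch.int32, device=A.device
        )
        bmm_2_4(A, B, meta, C, s, workspace, max_par=max_par)
        return C
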
diff --git a/triteia/python/capi/marlin.py b/triteia/python/capi/marlin.py
index 89b5ef5..114fba1 100644
--- a/triteia/python/capi/marlin.py
+++ b/triteia/python/capi/marlin.py
@@ -1,7 +1,7 @@
-import marlin_cuda
+import triteia_cuda
 
 
-def marlin_mul_2_4(
+def mul_2_4(
     A, B, meta, C, s, workspace, thread_k=-1, thread_m=-1, sms=-1, max_par=16
 ):
     """Marlin FP16x(INT4+2:4 sparsity) multiply; can be used within `torch.compile`.
@@ -19,4 +19,14 @@ def marlin_mul_2_4(
     max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
     ----
     """
-    marlin_cuda.mul_2_4(A, B, meta, C, s, workspace, thread_k, thread_m, sms, max_par)
+    triteia_cuda.mul_2_4(A, B, meta, C, s, workspace, thread_k, thread_m, sms, max_par)
+
+
+def bmm_2_4(
+    A, B, meta, C, s, workspace, thread_k=-1, thread_m=-1, sms=-1, max_par=16
+):
+    """FP16x(INT4+2:4 sparsity) batched matrix multiplication; can be used within `torch.compile`.
+    ----
+    Parameters:
+    A: fp16 activations of shape (n, k), one row per request
+    B: packed 4-bit, 2:4-sparse weight slices, one per request
+    meta: 2:4 sparsity metadata, one slice per request
+    C: fp16 output buffer of shape (n, m), written in place
+    s: fp16 quantization scales of shape (n, 1, m)
+    workspace: int32 scratch of at least n / 128 * max_par entries
+    thread_k, thread_m, sms, max_par: tuning knobs, as for `mul_2_4`
+    """
+    triteia_cuda.bmm_2_4(A, B, meta, C, s, workspace, thread_k, thread_m, sms, max_par)
\ No newline at end of file
diff --git a/triteia/python/configs/gpus/specs.py b/triteia/python/configs/gpus/specs.py
new file mode 100644
index 0000000..c752068
--- /dev/null
+++ b/triteia/python/configs/gpus/specs.py
@@ -0,0 +1,10 @@
+nvidia_rtx_3090 = {
+    'name': 'NVIDIA RTX 3090',
+    'compute_capability': '8.6',
+    'memory': 24,  # in GB
+    'bandwidth': 936.2,  # in GB/s
+    'fp16_tflops': 35.58,
+    'fp32_tflops': 35.58,
+}
+
+nvidia_gpus = [nvidia_rtx_3090]
\ No newline at end of file
diff --git a/triteia/python/nn/linear.py b/triteia/python/nn/linear.py
index 8927356..4a38041 100644
--- a/triteia/python/nn/linear.py
+++ b/triteia/python/nn/linear.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from triteia.python.capi import marlin_mul_2_4
+from triteia.python.capi import mul_2_4
 from triteia.python.ops.utils.sparsity import (
     _perm_2_4,
     _scale_perm_2_4,
@@ -65,7 +65,7 @@ def forward(self, x):
             x.shape[:-1] + (self.outfeatures,), dtype=x.dtype, device=x.device
         )
         self.workspace = self.workspace.to(x.device)
-        marlin_mul_2_4(
+        mul_2_4(
             x.view((-1, x.shape[-1])),
             self.qweight,
             self.meta,
diff --git a/triteia/python/ops/__init__.py b/triteia/python/ops/__init__.py
index cdef958..5bab041 100644
--- a/triteia/python/ops/__init__.py
+++ b/triteia/python/ops/__init__.py
@@ -1,6 +1,13 @@
 from .matmul.sparse_low_precision import matmul_4bit_2_4
+from .matmul.bmm import bmm_4bit_2_4_forloop
 from .utils.sparsity import mask_creator
-from .utils.generator import gen_quant4_NT
+from .utils.generator import gen_sparse_quant4_NT, gen_batched_sparse_quant4_NT
 
 
-__all__ = ["matmul_4bit_2_4", "mask_creator", "gen_quant4_NT"]
+__all__ = [
+    "matmul_4bit_2_4",
+    "bmm_4bit_2_4_forloop",
+    "mask_creator",
+    "gen_sparse_quant4_NT",
+    "gen_batched_sparse_quant4_NT"
+]
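The GPU spec table added in triteia/python/configs/gpus/specs.py presumably feeds the (currently empty) triteia/python/utils/benchmark.py introduced at the end of this diff. One hypothetical use, shown below, is a memory-traffic lower bound to compare measured kernel times against; the helper and formula are ours, not part of this change.

    from triteia.python.configs.gpus.specs import nvidia_rtx_3090

    def min_weight_load_ms(b, m, k, spec=nvidia_rtx_3090):
        # Rough lower bound: time to stream b slices of 2:4-compressed 4-bit
        # weights (k*m/2 values at 0.5 bytes each, ignoring metadata and scales)
        # at the card's peak bandwidth (GB/s).
        weight_bytes = b * (k * m / 2) * 0.5
        return weight_bytes / (spec["bandwidth"] * 1e9) * 1e3

    # Arbitrary example shapes: 16 requests of 4096x4096 weights.
    print(f"{min_weight_load_ms(16, 4096, 4096):.3f} ms on {nvidia_rtx_3090['name']}")
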
diff --git a/triteia/python/ops/matmul/bmm.py b/triteia/python/ops/matmul/bmm.py
new file mode 100644
index 0000000..0ccad71
--- /dev/null
+++ b/triteia/python/ops/matmul/bmm.py
@@ -0,0 +1,17 @@
+import torch
+from .sparse_low_precision import matmul_4bit_2_4
+
+def bmm_4bit_2_4_forloop(qweights, xs, metas, ss):
+    """
+    Batched low-precision sparse matrix multiplication with 2:4 sparsity.
+    Dispatches the single-matrix `matmul_4bit_2_4` kernel once per request.
+    ----
+    Parameters:
+    qweights: packed 4-bit, 2:4-sparse weights, one slice per request
+    xs: fp16 activations, one (rows, k) matrix per request
+    metas: 2:4 sparsity metadata, one slice per request
+    ss: fp16 quantization scales of shape (b, 1, m)
+    """
+    outputs = torch.zeros(
+        (xs.shape[0], xs.shape[1], ss.shape[2]), dtype=xs.dtype, device=xs.device
+    )
+    for id in range(xs.shape[0]):
+        outputs[id] = matmul_4bit_2_4(qweights[id], xs[id], metas[id], ss[id])
+    return outputs
+
\ No newline at end of file
diff --git a/triteia/python/ops/matmul/sparse_low_precision.py b/triteia/python/ops/matmul/sparse_low_precision.py
index 1b8bfdf..0db06a7 100644
--- a/triteia/python/ops/matmul/sparse_low_precision.py
+++ b/triteia/python/ops/matmul/sparse_low_precision.py
@@ -1,6 +1,5 @@
 import torch
-from triteia.python.capi import marlin_mul_2_4
-
+from triteia.python.capi import mul_2_4
 
 def matmul_4bit_2_4(qweight, x, meta, s):
     """Low precision sparse matrix multiplication with 2:4 sparsity.
@@ -13,7 +12,7 @@ def matmul_4bit_2_4(qweight, x, meta, s):
     """
     C = torch.zeros((x.shape[:-1] + (s.shape[1],)), dtype=x.dtype, device=x.device)
     workspace = torch.zeros(s.shape[1], dtype=torch.int32, device=x.device)
-    marlin_mul_2_4(
+    mul_2_4(
         x,
         qweight,
         meta,
diff --git a/triteia/python/ops/utils/generator.py b/triteia/python/ops/utils/generator.py
index 50dcfe2..3cb600d 100644
--- a/triteia/python/ops/utils/generator.py
+++ b/triteia/python/ops/utils/generator.py
@@ -3,13 +3,11 @@
 from .sparsity import mask_creator
 
 
-def gen_quant4_NT(m, k, groupsize=-1, device="cuda", prune_n=2, prune_m=4):
+def gen_sparse_quant4_NT(m, k, groupsize=-1, device="cuda", prune_n=2, prune_m=4):
     from triteia.python.nn.linear import sparse_low_precision_linear
-
     maxq = 2**4 - 1
     w = torch.randn((m, k), dtype=torch.half, device=device)
     k_sp = k // 2
-
     w = w.t()
     if groupsize != -1:
         w = w.reshape((-1, groupsize, m))
@@ -22,23 +20,19 @@ def gen_quant4_NT(m, k, groupsize=-1, device="cuda", prune_n=2, prune_m=4):
     w = torch.clamp(w, 0, maxq)
     ref = (w - (maxq + 1) // 2).half() * s
     if groupsize != -1:
-
         def reshape(w):
             w = w.reshape((groupsize, -1, m))
             w = w.permute(1, 0, 2)
             w = w.reshape((k, m)).contiguous()
             return w
-
         ref = reshape(ref)
         w = reshape(w)
     mask = mask_creator(w.T, n=prune_n, m=prune_m).cuda().bool()
     uncompress = (mask * ref.T).T
-
     s = s.reshape((-1, m)).contiguous()
     linear = nn.Linear(k, m)
     linear.weight.data = ref
-
-    layer = sparse_low_precision_linear(256, 256, groupsize=groupsize)
+    layer = sparse_low_precision_linear(m, k, groupsize=groupsize)
     if groupsize == -1:
         groupsize = k
     layer.k = k
@@ -53,5 +47,21 @@ def reshape(w):
     q = layer.B
     s = layer.s
     meta = layer.meta
-
     return uncompress, q, s, meta
+
+
+def gen_batched_sparse_quant4_NT(b, m, n, groupsize=-1, device="cuda:0"):
+    metas = []
+    qs = []
+    scales = []
+    uncompressed = []
+    for i in range(b):
+        unc, q, s, meta = gen_sparse_quant4_NT(m, n, groupsize=groupsize, device=device)
+        uncompressed.append(unc.t())
+        qs.append(q)
+        scales.append(s)
+        metas.append(meta)
+    uncompressed = torch.stack(uncompressed).to(device)
+    qs = torch.stack(qs).to(device)
+    scales = torch.stack(scales).to(device)
+    metas = torch.stack(metas).to(device)
+    return uncompressed, qs, scales, metas
\ No newline at end of file
diff --git a/triteia/python/utils/__init__.py b/triteia/python/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/triteia/python/utils/benchmark.py b/triteia/python/utils/benchmark.py
new file mode 100644
index 0000000..e69de29
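triteia/python/utils/benchmark.py lands empty in this change. A minimal sketch of the CUDA-event timing loop such a helper could host, assuming its job is to compare the for-loop batched op against the fused bmm_2_4 binding; the function name is a placeholder, not the shipped API.

    import torch

    def time_cuda_ms(fn, *args, warmup=5, iters=20):
        # Event-based timing for any of the ops above; returns mean milliseconds.
        for _ in range(warmup):
            fn(*args)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters

Such a helper could then be called on identical generated inputs for both batched paths to quantify the benefit of the fused kernel.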