diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 6ab709b09..5ca6b4e34 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -37,6 +37,7 @@ struct ParamsBase { , gamma(nullptr) , workspace(nullptr) , barrier(nullptr) + , is_rms_only(false) { } @@ -59,6 +60,9 @@ struct ParamsBase { // Multi-CTA sync barriers in gmem. int *barrier; + //Indicates whether it is RMSnorm or not + bool is_rms_only; + }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 30e4a5fec..60103d4ec 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -239,8 +239,171 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size //////////////////////////////////////////////////////////////////////////////////////////////////// +std::vector rmsnorm_fwd(const at::Tensor &x, // BxSxhidden_size + const at::Tensor &gamma, // hidden_size + const float epsilon +) { + auto itype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = wtype; + auto ctype = torch::kFloat32; + + // TORCH_CHECK(beta.scalar_type() == wtype); + + TORCH_CHECK(x.is_cuda()) + TORCH_CHECK(gamma.is_cuda()) + // TORCH_CHECK(beta.is_cuda()) + + TORCH_CHECK(x.is_contiguous()); + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + + const int rows = sizes[0]; + const int cols = sizes[1]; + auto hidden_size = gamma.numel(); + + // TORCH_CHECK(gamma.sizes() == beta.sizes()); + TORCH_CHECK(hidden_size == cols); + + TORCH_CHECK(epsilon >= 0.f); + + auto opts = x.options(); + + auto z = torch::empty(sizes, opts.dtype(otype)); + + // auto mu = torch::empty({ rows }, opts.dtype(ctype)); + auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); + + layer_norm::LaunchParams launch_params; + + launch_params.props = at::cuda::getCurrentDeviceProperties(); + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + + // Request the kernel launcher. + auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size); + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + at::Tensor workspace, barrier; + + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + // params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.z = z.data_ptr(); + params.epsilon = epsilon; + params.is_rms_only = true; + + if( launch_params.barrier_size > 0 ) { + auto options = x.options(); + barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + // Launch the kernel. + launcher(launch_params, false); + + return { z, rsigma }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector rmsnorm_bwd(const at::Tensor &dz, // BxSxhidden_size + const at::Tensor &x, // BxSxhidden_size + const at::Tensor &rsigma, // BxS, FP32! 
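+                                 // (no mu tensor here: RMSNorm subtracts no mean,
+                                 //  so only the per-row rsigma is needed for backward)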
+ const at::Tensor &gamma // hidden_size +) { + + auto itype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = wtype; + auto ctype = torch::kFloat32; + + TORCH_CHECK(dz.dtype() == otype); + // TORCH_CHECK(mu.dtype() == ctype); + TORCH_CHECK(rsigma.dtype() == ctype); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(dz.is_cuda()); + // TORCH_CHECK(mu.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + + TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(dz.is_contiguous()); + + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + TORCH_CHECK(dz.sizes() == sizes); + auto rows = sizes[0]; + auto cols = sizes[1]; + + auto hidden_size = gamma.numel(); + + // TORCH_CHECK(mu.numel() == rows); + // TORCH_CHECK(mu.sizes() == rsigma.sizes()); + + TORCH_CHECK(gamma.numel() == cols); + + auto options = x.options(); + + auto dx = torch::empty_like(x); + auto dgamma = torch::empty_like(gamma); + // auto dbeta = torch::empty_like(gamma); + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + + auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size); + + launcher(launch_params, true); + + auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); + // auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + // params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.dz = dz.data_ptr(); + params.dx = dx.data_ptr(); + // params.dbeta = dbeta.data_ptr(); + params.dgamma = dgamma.data_ptr(); + // params.dbeta_part = dbeta_part.data_ptr(); + params.dgamma_part = dgamma_part.data_ptr(); + params.is_rms_only = true; + + if( launch_params.barrier_size > 0 ) { + // TODO Any way to avoid this? 
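+        // The multi-CTA reduction path needs a zero-initialized set of gmem barriers
+        // plus a scratch workspace; both are allocated on the fly for every call here.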
+ barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + launcher(launch_params, false); + + return { dx, dgamma, dgamma_part }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; + m.doc() = "CUDA LayerNorm & RMSNorm"; m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel"); m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "Run RMSNorm forward kernel"); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "Run RMSNorm backward kernel"); } diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 8595f5ed4..7932dbb26 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -189,6 +189,162 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { } } +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void rmsnorm_bwd_kernel(layer_norm::BwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Reducer = typename Ktraits::Reducer_single; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dzy_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + + compute_t * smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + constexpr float rn = 1.f / float(COLS); + Wvec gamma[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + gamma[it].load_from(params.gamma, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the + // last blocks with syncthreads! 
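+    // Per row, with y = rs * x and dy = gamma * dz, the RMSNorm input gradient is
+    //   dx = rs * (dy - mean(dy * y) * y),
+    // while dzy_sum accumulates dz * y per column as this CTA's partial dgamma.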
+ // grid stride over rows + #pragma unroll 1 + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rs_r = static_cast(params.rs)[row]; + Ivec x[LDGS]; + Ovec dz[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz[it].load_from(params.dz, idx); + x[it].load_from(params.x, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + + compute_t mdyy_local = 0.f; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_tmp = x[it].data.elt[jt]; + compute_t y_tmp = rs_r * x_tmp; + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); + dy_tmp *= compute_t(dz[it].data.elt[jt]); + compute_t dz_tmp = dz[it].data.elt[jt]; + + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + } + } + + auto result = reducer.allreduce(mdyy_local, sum); + mdyy_local = result * rn; + + Ivec dx[LDGS]; + idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - mdyy_local * y_tmp); + dx[it].data.elt[jt] = dx_tmp; + } + dx[it].store_to(params.dx, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + + } // end: grid stride loop + + if( WARPS_M == 1 ) { + idx = r * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dzy_sum[it].store_to(params.dgamma_part, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. 
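+        // Stage each warp's dzy_sum in shared memory, then have every thread sum the
+        // ROWS_PER_CTA staged rows for its columns and write this CTA's dgamma partial
+        // to gmem; the separate finalize kernel reduces the partials across CTAs.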
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; + for( int jt = 0; jt < NUM_RES; jt++ ) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; + } + } +} + template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_kernel(BwdParams params) @@ -312,4 +468,128 @@ void ln_bwd_finalize_kernel(BwdParams params) } } } + +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) +void rmsnorm_bwd_finalize_kernel(BwdParams params) +{ + + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { + // Each thread sums over NUM_ELT columns. + Vec dgamma_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + // memset(&dbeta_local, 0, sizeof(dbeta_local)); + for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { + index_t idx = row * Kernel_traits::COLS + col; + + Vec dgamma_part; + // dbeta_part.load_from(params.dbeta_part, idx); + dgamma_part.load_from(params.dgamma_part, idx); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + // dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + } + } + + void * smem_gamma = smem_; + // void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; + + dgamma_local.store_to(smem_gamma, write_idx); + // dbeta_local.store_to(smem_beta, write_idx); + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta and smem_gamma + void * smem_gamma_out = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + // void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + + + // More than one iter iff ROWS_PER_CTA < 32. 
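+        // Read the per-warp partials back transposed (the XOR swizzle avoids smem bank
+        // conflicts), combine them with an allreduce across the warp's lanes, and let
+        // lane 0 stash the reduced dgamma for this column group in the output staging buffer.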
+ for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; + + // memset(&dbeta_local, 0, sizeof(dbeta_local)); + memset(&dgamma_local, 0, sizeof(dgamma_local)); + + // Load beta and gamma transposed + if(read_row < Kernel_traits::ROWS_PER_CTA){ + // dbeta_local.load_from(smem_beta, read_idx); + dgamma_local.load_from(smem_gamma, read_idx); + } + + // Call reducer on the loaded value(s) and convert. + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + // compute_t b_i = dbeta_local.data.elt[it]; + compute_t g_i = dgamma_local.data.elt[it]; + // b_i = reducer.allreduce(b_i, sum); + g_i = reducer.allreduce(g_i, sum); + + dgamma_local.data.elt[it] = g_i; + // dbeta_local.data.elt[it] = b_i; + } + + // Leader stores the result at the current column. + if(lane == 0){ + dgamma_local.store_to(smem_gamma_out, w); + // dbeta_local.store_to(smem_beta_out, w); + } + + } + + // All writes done. + __syncthreads(); + + // Pack and store: 2-wide stores with half the threads. + if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { + + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dgamma_vec2; + Vec dgamma_out2; + + dgamma_vec2.load_from(smem_gamma_out, lane); + // dbeta_vec2.load_from(smem_beta_out, lane); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); + // dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); + } + dgamma_out2.store_to(params.dgamma, col_out); + // dbeta_out2.store_to(params.dbeta, col_out); + + } + } +} } // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu index 3893d4e0c..84aa88f16 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -31,7 +31,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params WARPS_N, BYTES_PER_LDG_MAIN >; - auto kernel = &ln_bwd_kernel; + auto kernel = launch_params.params.is_rms_only ? &rmsnorm_bwd_kernel : &ln_bwd_kernel; if( configure_params ) { int ctas_per_sm; @@ -45,7 +45,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) + * (launch_params.params.is_rms_only ? sizeof(compute_t) : sizeof(typename Kernel_traits::reduce_t)) * 2; } return; @@ -75,7 +75,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params 32 * 32, // THREADS_PER_CTA BYTES_PER_LDG_FINAL>; - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + auto kernel_f = launch_params.params.is_rms_only ? &rmsnorm_bwd_finalize_kernel : &layer_norm::ln_bwd_finalize_kernel; kernel_f<<>>(launch_params.params); } diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu index dc4e89cf5..ce4410945 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -30,7 +30,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params WARPS_N, BYTES_PER_LDG >; - auto kernel = &ln_fwd_kernel; + auto kernel = launch_params.params.is_rms_only ? 
&rmsnorm_fwd_kernel : &ln_fwd_kernel; if( configure_params ) { int ctas_per_sm; @@ -44,7 +44,7 @@ void launch_(LaunchParams &launch_params, const bool configure_params launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) + * (launch_params.params.is_rms_only ? sizeof(compute_t) : sizeof(typename Kernel_traits::Stats::stats_t)) * 2; } return; diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh index 64e72974f..b8dab9f67 100644 --- a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh @@ -107,4 +107,98 @@ void ln_fwd_kernel(FwdParams params) { } } +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void rmsnorm_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + + using Reducer = typename Ktraits::Reducer_single; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *rs_ptr = static_cast(params.rs); + + Wvec gamma[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + gamma[it].load_from(params.gamma, idx); + idx += VEC_COLS_PER_LDG; + } + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + Ivec x[LDGS]; + index_t idx = row * Ktraits::VEC_COLS + c; + compute_t xf[LDGS * NUM_ELTS]; + compute_t m = Zeros::get(); + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + x[it].load_from(params.x, idx); + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_ij = compute_t(x[it].data.elt[jt]); + xf[it * NUM_ELTS + jt] = x_ij; + m += x_ij * x_ij * rn; + } + idx += VEC_COLS_PER_LDG; + } + + auto sum = Sum(); + m = reducer.allreduce(m, sum); + + compute_t rs = rsqrtf(m + params.epsilon); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + Ovec z[LDGS]; + idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + output_t y_ij = output_t(rs * xf[it * NUM_ELTS + jt]); + output_t g_ij = gamma[it].data.elt[jt]; + z[it].data.elt[jt] = g_ij * y_ij; + } + z[it].store_to(params.z, idx); + idx += VEC_COLS_PER_LDG; + } + + } +} 
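+// Summary of rmsnorm_fwd_kernel above: for each row, m = mean(x * x) is computed via a
+// cross-thread allreduce, rs = rsqrt(m + epsilon) is written once per row for the backward
+// pass, and the output z = gamma * (rs * x) is stored.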
+ } // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h index ed745c5ee..494b9aeeb 100644 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h @@ -128,6 +128,8 @@ struct Kernel_traits : public Base { using reduce_t = typename layer_norm::TypeToVec2::Type; using Reducer = layer_norm::Reducer; + using Reducer_single = layer_norm::Reducer; + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; diff --git a/apex/contrib/layer_norm/__init__.py b/apex/contrib/layer_norm/__init__.py index 4bbc4763a..c24020e2e 100644 --- a/apex/contrib/layer_norm/__init__.py +++ b/apex/contrib/layer_norm/__init__.py @@ -1 +1 @@ -from .layer_norm import FastLayerNorm +from .layer_norm import FastLayerNorm, FastRMSNorm diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index b084b1ace..724c85138 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -51,3 +51,47 @@ def reset_parameters(self): def forward(self, x): return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) + +class FastRMSNormFN(torch.autograd.Function): + @staticmethod + def forward(ctx, x, gamma, epsilon): + x = x.contiguous() + gamma = gamma.contiguous() + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + ymat, rsigma = fast_layer_norm.rmsnorm_fwd(xmat, gamma, epsilon) + ctx.save_for_backward(x, gamma, rsigma) + return ymat.view(x.shape) + + @staticmethod + def backward(ctx, dy): + # assert dy.is_contiguous() + dy = dy.contiguous() # this happens! + x, gamma, rsigma = ctx.saved_tensors + + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dymat = dy.view(xmat.shape) + dxmat, dgamma, _ = fast_layer_norm.rmsnorm_bwd(dymat, xmat, rsigma, gamma) + dx = dxmat.view(x.shape) + return dx, dgamma, None + + +def _fast_rms_norm(x, weight, epsilon): + args = _cast_if_autocast_enabled(x, weight, epsilon) + with torch.cuda.amp.autocast(enabled=False): + return FastRMSNormFN.apply(*args) + + +class FastRMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5): + super().__init__() + self.epsilon = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x): + return _fast_rms_norm(x, self.weight, self.epsilon) diff --git a/apex/contrib/test/layer_norm/test_fast_rms_norm.py b/apex/contrib/test/layer_norm/test_fast_rms_norm.py new file mode 100644 index 000000000..86692c33e --- /dev/null +++ b/apex/contrib/test/layer_norm/test_fast_rms_norm.py @@ -0,0 +1,265 @@ +import unittest + +import torch + +SKIP_TEST = None +try: + from apex.contrib.layer_norm.layer_norm import FastRMSNorm + import fast_layer_norm as fln +except ImportError as e: + SKIP_TEST = e + + +class GPUTimer: + def __init__(self, stream): + self.start_ = torch.cuda.Event(enable_timing=True) + self.stop_ = torch.cuda.Event(enable_timing=True) + self.stream_ = stream + + def start(self): + self.stream_.record_event(self.start_) + + def stop(self): + self.stream_.record_event(self.stop_) + + def sync(self): + self.stream_.synchronize() + + def millis(self): + return self.start_.elapsed_time(self.stop_) + + +def size_in_bytes(t): + return torch.numel(t) * t.element_size() + + +def metrics(y_ref, y, epsilon=1e-6): + y_ref = y_ref.float() + y = y.float() + relerr, mse = ( + 
(y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon), + (y_ref - y).square().mean(), + ) + return relerr.item(), mse.item() + + +device = torch.device("cuda") +fp32 = torch.float32 +fp16 = torch.float16 +bf16 = torch.bfloat16 + + +def backward_(dz, x, rs, gamma): + + wtype = gamma.dtype + itype = x.dtype + otype = dz.dtype + ctype = rs.dtype + rs = rs.unsqueeze(1) + + hidden_size = gamma.numel() + y = rs * x.to(ctype) + dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype) + dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype) + + mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype) + dx = rs * (dy - mdyy * y) + + return dx.to(itype), dgamma.to(wtype) + + +def benchmark_(S, B, hidden_size, itype, wtype, runs=100): + epsilon = 1e-5 + + x = torch.randn((S * B, hidden_size), dtype=itype, device=device) + # beta = torch.randn(hidden_size, dtype=wtype, device=device) + gamma = torch.randn(hidden_size, dtype=wtype, device=device) + dz = torch.randn(x.shape, dtype=wtype, device=device) + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + + timer = GPUTimer(stream) + + # warmup + for r in range(runs): + z, rsigma = fln.rmsnorm_fwd(x, gamma, epsilon) + + timer.start() + for r in range(runs): + z, rsigma = fln.rmsnorm_fwd(x, gamma, epsilon) + timer.stop() + timer.sync() + + total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, rsigma]]) + + ms_fwd = timer.millis() / runs + + print( + "[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format( + ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd + ) + ) + + timer.start() + for r in range(runs): + dx, dgamma, dgp = fln.rmsnorm_bwd(dz, x, rsigma, gamma) + timer.stop() + timer.sync() + + total_bytes_bwd = sum( + [ + size_in_bytes(t) + for t in [dz, x, rsigma, gamma, dx, dgamma, dgp, dgp] + ] + ) + + ms_bwd = timer.millis() / runs + + print( + "[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format( + ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd + ) + ) + +def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): + + seed = 1243 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + otype = wtype + print("========================================================") + print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}") + print("--------------------------------------------------------") + + x = torch.randn(S * B, hidden_size, dtype=itype, device=device) + gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 + epsilon = 1e-5 + + x.requires_grad = True + gamma.requires_grad = True + + v = torch.square(x.to(ctype)).mean(1, dtype=ctype, keepdim=True) + rs_ref = torch.rsqrt(v + epsilon) + y_ref = rs_ref * x.to(ctype) + z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype)).to(otype) + + rs_ref = rs_ref.flatten() + + dz = torch.randn_like(z_ref) + + # z_ref.backward(dz) + # dx_ref = x.grad + # dgamma_ref = gamma.grad + # dbeta_ref = beta.grad + + dx_ref, dg_ref = backward_(dz, x, rs_ref, gamma) + + z, rs = fln.rmsnorm_fwd(x, gamma, epsilon) + dx, dg, dg_part = fln.rmsnorm_bwd(dz, x, rs, gamma) + + re_z, mse_z = metrics(z_ref, z) + re_rs, mse_rs = metrics(rs_ref, rs) + + re_dx, mse_dx = metrics(dx_ref, dx) + re_dg, mse_dg = metrics(dg_ref, dg) + + print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}") + print(f"rs: relerr={re_rs:.4e} mse={mse_rs:.4e}") + + print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}") + print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}") + + def check_err(x, relerr): + tol = 1e-3 if x.dtype == torch.float16 else 5e-6 + return relerr < tol + + return [ + check_err(x, re) + for x, re in zip([z, rs, 
dx, dg], [re_z, re_rs, re_dx, re_dg]) + ] + +@unittest.skipIf(SKIP_TEST, f"{SKIP_TEST}") +class TestFastRMSNorm(unittest.TestCase): + # TODO(crcrpar): Try `torch.testing.assert_close` instead and migrate to it if it's working. + def assertAll(self, l): + if not all(l): + print(l) + for x in l: + self.assertTrue(x) + + def test_all_configs(self): + + hidden_sizes = [ + 768, + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 12288, + 12800, + 14336, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, + ] + + for h in hidden_sizes: + with self.subTest(f"hidden_size={h}"): + self.assertAll(_test_impl(256, 2, h, fp32, fp32)) + self.assertAll(_test_impl(256, 2, h, fp16, fp16)) + self.assertAll(_test_impl(256, 2, h, fp32, fp16)) + self.assertAll(_test_impl(256, 2, h, bf16, bf16)) + self.assertAll(_test_impl(256, 2, h, fp32, bf16)) + + def test_run_benchmark(self): + for (S, B, hidden_size, runs) in ( + (512, 32, 768, 1000), + (512, 32, 1024, 1000), + (512, 8, 4096, 1000), + (512, 8, 5120, 1000), + (512, 8, 6144, 1000), + (256, 2, 20480, 500), + (256, 2, 25600, 500), + (256, 2, 40960, 250), + (256, 2, 65536, 250), + ): + with self.subTest(f"(S, B, hidden_size)=({S}, {B}, {hidden_size})"): + benchmark_(S, B, hidden_size, fp16, fp16, runs) + + def test_compat_with_autocast(self): + autocast_dtypes = ( + (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) + ) + input_shape = (512, 32, 768) + rms_norm = FastRMSNorm(input_shape[-1]).cuda() + input = torch.randn(input_shape).cuda() + + for dtype in autocast_dtypes: + rms_norm.zero_grad(set_to_none=True) + with self.subTest(f"autocast_dtype={dtype}"): + with torch.cuda.amp.autocast(enabled=True, dtype=dtype): + out = rms_norm(input) + self.assertEqual(dtype, out.dtype) + grad = torch.randn_like(out) + out.backward(grad) + self.assertEqual(torch.float32, rms_norm.weight.grad.dtype) + +if __name__ == '__main__': + unittest.main()
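Note for reviewers: below is a minimal pure-PyTorch sketch of the math the new kernels implement, useful for checking the bindings by eye. The names rmsnorm_fwd_ref and rmsnorm_bwd_ref are illustrative only (not part of the extension's API) and mirror the backward_ helper used in the new test.

import torch

def rmsnorm_fwd_ref(x, gamma, epsilon=1e-5):
    # x: (rows, hidden), gamma: (hidden,); statistics in fp32, like the kernels' ctype.
    xf = x.float()
    rsigma = torch.rsqrt(xf.square().mean(dim=1, keepdim=True) + epsilon)
    z = (gamma.float() * (rsigma * xf)).to(gamma.dtype)
    return z, rsigma.squeeze(1)           # rsigma is saved for the backward pass

def rmsnorm_bwd_ref(dz, x, rsigma, gamma):
    rs = rsigma.unsqueeze(1)
    y = rs * x.float()                    # normalized input
    dy = dz.float() * gamma.float()
    mdyy = (dy * y).mean(dim=1, keepdim=True)
    dx = rs * (dy - mdyy * y)             # what rmsnorm_bwd_kernel writes
    dgamma = (dz.float() * y).sum(dim=0)  # what the finalize kernel reduces to
    return dx.to(x.dtype), dgamma.to(gamma.dtype)

# Usage sketch for the new module (assumes apex was built with --fast_layer_norm):
#   from apex.contrib.layer_norm import FastRMSNorm
#   norm = FastRMSNorm(hidden_size, eps=1e-5).cuda()
#   out = norm(x)   # x: (batch, seq, hidden) on the GPU

The _test_impl helper in the new test performs essentially this comparison against the fused kernels across dtypes and hidden sizes.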