diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
index 5e051cd73b..9c24093944 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu
@@ -138,20 +138,20 @@ __global__ void set_dynamic_kernel_args_kernel(
           GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
           problem_shape_buf);
   // Pass dummy configs to get Stride structure
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputA* stride_input_A_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputA*>(stride_buf);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputB* stride_input_B_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputB*>(stride_buf + stride_size);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideOutput* stride_output_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideOutput*>(stride_buf + (stride_size * 2));
 
     output_args_ptr[group_index] =
@@ -167,15 +167,15 @@ __global__ void set_dynamic_kernel_args_kernel(
             zero_start_index_M[group_index], N, K);
     stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
         {zero_start_index_M[group_index], K, 1});
     stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
         {N, K, 1});
     stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
         {zero_start_index_M[group_index], N, 1});
   }
 }
@@ -212,20 +212,20 @@ __global__ void set_static_kernel_args_kernel(
           GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape*>(
           problem_shape_buf);
   // Pass dummy configs to get Stride structure
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputA* stride_input_A_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputA*>(stride_buf);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideInputB* stride_input_B_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideInputB*>(stride_buf + stride_size);
-  GroupedGemmBF16Args::GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+  GroupedGemmBF16Args::GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
       StrideOutput* stride_output_ptr = reinterpret_cast<
           GroupedGemmBF16Args::
-              GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::
+              GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::
                   StrideOutput*>(stride_buf + (stride_size * 2));
     output_args_ptr[group_index] =
         reinterpret_cast<GroupedGemmBF16Args::ElementOutput*>(output_data);
@@ -237,15 +237,15 @@ __global__ void set_static_kernel_args_kernel(
         GroupedGemmBF16Args::ProblemShape::UnderlyingProblemShape(M, N, K);
     stride_input_A_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputA{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputA{},
         {M, K, 1});
     stride_input_B_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideInputB{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideInputB{},
         {N, K, 1});
     stride_output_ptr[group_index] = cutlass::make_cute_packed_stride(
         typename GroupedGemmBF16Args::
-            GroupedGemmConfigs<128, 128, 128, 2, 1, 1, true>::StrideOutput{},
+            GroupedGemmConfigs<128, 256, 64, 2, 1, 1, false>::StrideOutput{},
         {M, N, 1});
   }
 }
@@ -470,10 +470,10 @@ std::vector<at::Tensor> dispatch_bf16_grouped_kernel(
     return bf16bf16bf16_grouped_impl<64, 128, 128, 2, 1, 1, true>(
         x_group, w_group, output_tensor, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
-    return bf16bf16bf16_grouped_impl<128, 128, 128, 2, 1, 1, true>(
+    return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
         x_group, w_group, output_tensor, zero_start_index_M);
   } else {
-    return bf16bf16bf16_grouped_impl<128, 128, 128, 1, 2, 1, true>(
+    return bf16bf16bf16_grouped_impl<128, 256, 64, 2, 1, 1, false>(
         x_group, w_group, output_tensor, zero_start_index_M);
   }
 }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
index e9726fbf43..abceb2f984 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise.cu
@@ -98,13 +98,18 @@ at::Tensor bf16i4bf16_rowwise_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -231,18 +236,18 @@ at::Tensor dispatch_bf16i4bf16_rowwise_kernel(
   } else if (kernel == KernelMode::Large) {
     return bf16i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
-        true,
+        false,
         WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
   } else {
     return bf16i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
index 871543a2ff..bd3e3c7f1a 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16i4bf16_rowwise_batched.cu
@@ -102,13 +102,18 @@ at::Tensor bf16i4bf16_rowwise_batched_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -235,17 +240,17 @@ at::Tensor dispatch_bf16i4bf16_rowwise_batched_kernel(
   } else if (kernel == KernelMode::Large) {
     return bf16i4bf16_rowwise_batched_impl<
         128,
-        128,
+        256,
         64,
         2,
         1,
         1,
-        true,
+        false,
         WEIGHT_SCALE_DTYPE>(X, WQ, w_scale, w_zp);
   } else {
     return bf16i4bf16_rowwise_batched_impl<
         128,
-        128,
+        256,
         64,
         2,
         1,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
index cef942ecd8..fbb0546c97 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu
@@ -92,13 +92,18 @@ at::Tensor f8i4bf16_rowwise_impl(
       cute::Int<TBS_K>>; // Shape of the
                          // threadblocks in a
                          // cluster
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedMixedInput;
+  using CooperativeSchedule =
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
   using PongSchedule =
       cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput;
-  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::TmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
   using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
   using MainLoopSchedule =
-      cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
 
   // Implement rowwise scaling epilogue for x
   using XScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
@@ -254,19 +259,19 @@ at::Tensor dispatch_f8i4bf16_rowwise_kernel(
   } else if (kernel == KernelMode::Large) {
     return f8i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
-        true,
+        false,
         InputDType,
         WEIGHT_SCALE_DTYPE>(XQ, WQ, x_scale, w_scale, w_zp);
   } else {
     return f8i4bf16_rowwise_impl<
         128,
-        128,
-        128,
+        256,
+        64,
         2,
         1,
         1,
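Reviewer note (commentary, not part of the patch): each _impl kernel derives its mainloop and epilogue schedules from the trailing boolean (PONG) template parameter via cute::conditional_t, so flipping true to false in the dispatchers moves the Large and default paths from the ping-pong to the cooperative schedule, alongside the new 128x256x64 tile shape. Below is a minimal stand-alone C++ sketch of that selection pattern; the tag types and GemmConfig are hypothetical stand-ins, not the real CUTLASS classes.

// Illustrative sketch only: the PONG-style schedule selection used above,
// with stand-in tag types rather than the real CUTLASS schedule classes.
#include <cstdio>
#include <type_traits>

struct CooperativeSchedule {
  static constexpr const char* name = "cooperative";
};
struct PongSchedule {
  static constexpr const char* name = "pingpong";
};

// Hypothetical config mirroring the <TB_M, TB_N, TB_K, ..., PONG> parameter
// lists in the patch: PONG = true picks ping-pong, false picks cooperative.
template <int TB_M, int TB_N, int TB_K, bool PONG>
struct GemmConfig {
  using MainLoopSchedule =
      std::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
};

int main() {
  // Old Large path: <128, 128, 128, ..., true>; new: <128, 256, 64, ..., false>.
  using OldLarge = GemmConfig<128, 128, 128, true>;
  using NewLarge = GemmConfig<128, 256, 64, false>;
  std::printf(
      "old large: %s\nnew large: %s\n",
      OldLarge::MainLoopSchedule::name,
      NewLarge::MainLoopSchedule::name);
  return 0;
}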