diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 196ec8cfce88e..a8cf2211cdaa0 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -10,7 +10,8 @@ from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) - +from vllm import envs +from torch import nn def main(model, tp_size, gpu, dtype: str): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) @@ -154,6 +155,15 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, device=hidden_states.device, dtype=hidden_states.dtype, ) + if envs.VLLM_MOE_PADDING: + w1 = nn.Parameter(F.pad(w1.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + w2 = nn.Parameter(F.pad(w2, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() w1_scale = None w2_scale = None diff --git a/benchmarks/kernels/benchmark_mixtral_moe_decode.py b/benchmarks/kernels/benchmark_mixtral_moe_decode.py new file mode 100644 index 0000000000000..30f2b182738bb --- /dev/null +++ b/benchmarks/kernels/benchmark_mixtral_moe_decode.py @@ -0,0 +1,255 @@ +import argparse +import json +import os +import sys + +import torch +import torch.nn.functional as F +import triton +from tqdm import tqdm +from vllm import envs +from torch import nn +from vllm.model_executor.layers.fused_moe import (fused_moe, + get_config_file_name) + + +def main(model, tp_size, gpu, dtype: str): + os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) + method = fused_moe + # for bs in [ + # 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, + # 2048, 3072, 4096 + # ]: + for bs in [8, 16, 32, 64, 96, 112, 120, 128]: + run_grid(bs, + model=model, + method=method, + gpu=gpu, + tp_size=tp_size, + dtype=dtype) + + +def run_grid(bs, model, method, gpu, tp_size, dtype: str): + if model == '8x7B': + d_model = 4096 + model_intermediate_size = 14336 + num_layers = 32 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + num_layers = 56 + else: + raise ValueError(f'Unsupported Mixtral model {model}') + num_total_experts = 8 + top_k = 2 + # tp_size = 2 + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + configs = [] + + for block_size_n in [32, 64, 128, 256]: + for block_size_m in [16, 32, 64, 128, 256]: + for block_size_k in [64, 128, 256]: + for group_size_m in [1, 16, 32, 64]: + for num_warps in [4, 8]: + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + best_config = None + best_time_us = 1e20 + + print(f'{tp_size=} {bs=}') + + # for config in tqdm(configs): + if 1: + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=None, + dtype=dtype, + ) + except triton.runtime.autotuner.OutOfResources: + #continue + pass + + # trial + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + method=method, + config=None, + dtype=dtype, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + model_dur_ms = kernel_dur_ms * 
num_layers + + if kernel_dur_us < best_time_us: + # best_config = config + best_time_us = kernel_dur_us + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') + + print("best_time_us", best_time_us) + print("best_config", best_config) + + # holds Dict[str, Dict[str, int]] + # filename = get_config_file_name(num_total_experts, + # model_intermediate_size // tp_size, + # "float8" if dtype == "float8" else None) + # print(f"writing config to file {filename}") + # existing_content = {} + # if os.path.exists(filename): + # with open(filename, "r") as f: + # existing_content = json.load(f) + # existing_content[str(bs)] = best_config + # with open(filename, "w") as f: + # json.dump(existing_content, f, indent=4) + # f.write("\n") + + +def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, + top_k: int, tp_size: int, model_intermediate_size: int, method, + config, dtype: str) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + + hidden_states = torch.rand( + (bs, d_model), + device="cuda:0", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + if envs.VLLM_MOE_PADDING: + w1 = nn.Parameter(F.pad(w1.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + w2 = nn.Parameter(F.pad(w2.data, + (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + + gating_output = F.softmax(torch.rand( + (num_calls, bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for i in range(num_calls): + hidden_states = method( + hidden_states=hidden_states, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + gating_output=gating_output[i], + topk=2, + renormalize=True, + inplace=True, + override_config=config, + use_fp8=dtype == "float8", + ) + end_event.record() + end_event.synchronize() + + + # torch_output = torch_moe(a, w1, w2, score, topk) + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + parser.add_argument('--model', + type=str, + default='8x7B', + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + parser.add_argument('--tp-size', + type=int, + 
default=2, + help='Tensor paralleli size') + parser.add_argument('--gpu', + type=int, + default=0, + help="GPU ID for benchmarking") + args = parser.parse_args() + sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype)) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 63080eaf2f11c..77fb8d3c966a0 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -17,30 +17,11 @@ def main(args): - os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID os.environ["HIP_FORCE_DEV_KERNARG"] = "1" os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" - os.environ["OPTIMIZE_EPILOGUE"] = "1" for bs in [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, ]: run_grid(bs, model=args.model, TP=args.TP) @@ -49,21 +30,22 @@ def main(args): def get_full_tuning_space(): configs = [] - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] + block_m_range = [32] + block_n_range = [128] + block_k_range = [128] # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] + num_warps_range = [8] + group_m_range = [1] # For now we see better perf with num_stages=0 for all gemm configs we care # But keep this explicit so that we do not forget we may need to set it to # other values in the future num_stage_range = [0] waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] + matrix_instr_nonkdim_range = [16] + kpack_range = [2] - for block_m in block_mn_range: - for block_n in block_mn_range: + for block_m in block_m_range: + for block_n in block_n_range: for block_k in block_k_range: for num_warps in num_warps_range: for group_m in group_m_range: @@ -91,77 +73,8 @@ def get_full_tuning_space(): ## Utilize method from rocm/Triton tuning script def prune_configs(M, N, K, configs): - pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) - - mfma = 16 if M < 32 or N < 32 else 32 - - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - - for config in configs: - BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") - BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") - BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") - num_warps = config.get("num_warps") - matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - # kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: - continue - # some layouts could not work properly in case - # number elements per thread is less 1 - if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: - continue - SPLIT_K = 1 # config.get("SPLIT_K") - GROUP_M = config.get("GROUP_SIZE_M") - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: - continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: - continue - # Skip BLOCK_SIZE that is too large compare to M/N - # unless BLOCK_SIZE is already small enough - if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: - continue - if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: - continue - # skip large split_k when not necessary - if SPLIT_K != 1 and not need_split_k(M, N, K): - continue - # skip split_k that leads to 
EVEN_K = false - leap = SPLIT_K * BLOCK_SIZE_K - modv = K % leap - if modv != 0: - continue - # skip large GROUP_M - if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: - continue - # out of shared memory resource - # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) - if LDS > 65536: - continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - - pruned_configs.append(config) - return pruned_configs + return configs def union_of_list_of_dicts(l1, l2): @@ -195,7 +108,7 @@ def run_grid(bs, model, TP): num_calls = 100 num_warmup_trials = 1 - num_trials = 1 + num_trials = 10 full_configs = get_full_tuning_space() M1 = bs * 2 @@ -293,7 +206,7 @@ def run_timing( ) w1 = torch.rand( - (num_total_experts, 2 * shard_intermediate_size, d_model), + (num_total_experts, 2 * shard_intermediate_size, d_model+128), device=hidden_states.device, dtype=hidden_states.dtype, ) @@ -318,7 +231,7 @@ def run_timing( assert (hidden_states.shape[0] == gating_output.shape[0] ), "Number of tokens mismatch" - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - 128, "Hidden size mismatch" assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -393,32 +306,40 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, + use_fp8=False ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - None, # a2_scale - None, # w2_scale - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 - else tl.float16), - use_fp8=False, - ) + # ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + # invoke_fused_moe_kernel( + # intermediate_cache2, + # w2, + # intermediate_cache3, + # None, # a2_scale + # None, # w2_scale + # topk_weights, + # topk_ids, + # sorted_token_ids, + # expert_ids, + # num_tokens_post_padded, + # True, + # 1, + # config, + # compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 + # else tl.float16), + # use_fp8=False, + # ) end_event.record() end_event.synchronize() + # print(f"intermediate 0 shape = {intermediate_cache1.shape}") + # print(f"intermediate 1 shape = {intermediate_cache2.shape}") + # print(f"intermediate 2 shape = {intermediate_cache3.shape}") + # print(f"config = {config}") + # print(f"sorted token ids = {sorted_token_ids}") + # print(f"sorted token ids shape = {sorted_token_ids.shape}") + # print(f"expert ids = {expert_ids}") + # print(f"num_tokens_post_padded = {num_tokens_post_padded}") dur_ms = start_event.elapsed_time(end_event) / num_calls return dur_ms diff --git a/benchmarks/test_accuracy.py b/benchmarks/test_accuracy.py new file mode 100644 index 0000000000000..06c1150f9f20f --- /dev/null +++ b/benchmarks/test_accuracy.py @@ -0,0 +1,44 @@ +from vllm import LLM, SamplingParams +import time + + +def main(): + llm = LLM( + 
'/data/AI-ModelScope/Mixtral-8x7B-Instruct-v0___1/', + tensor_parallel_size=1, + #quantization="serenity", + dtype='float16', + #swap_space=16, + #enforce_eager=True, + #kv_cache_dtype="fp8", + #quantization="fp8", + #quantized_weights_path="/quantized/quark/llama.safetensors", + #worker_use_ray=True, + #trust_remote_code=True, + #distributed_executor_backend="mp", + ) + batch_size = 5 + max_tokens = 256 + prompt = """The sun is a""" + sampling_params = SamplingParams(temperature=0, + top_p=0.95, + max_tokens=max_tokens) + + start_time = time.perf_counter() + outs = llm.generate([prompt] * batch_size, sampling_params=sampling_params) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + + out_lengths = [len(x.token_ids) for out in outs for x in out.outputs] + num_tokens = sum(out_lengths) + + print( + f"{num_tokens} tokens. {num_tokens / batch_size} on average. {num_tokens / elapsed_time:.2f} tokens/s. {elapsed_time} seconds" + ) + for out in outs: + print("===========") + print(out.outputs[0].text) + + +if __name__ == "__main__": + main() diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 9e92187967d47..bf196b235178e 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -51,6 +51,61 @@ void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in, at::cuda::getCurrentCUDAStream(), CuCount); } +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + void* expert_ids, + void* num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount); + +void wvSpltK_fsdMoe(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + at::Tensor topk_weights, + at::Tensor topk_ids, + at::Tensor sorted_token_ids, + at::Tensor expert_ids, + at::Tensor num_tokens_post_padded, + const int M, const int N, const int K, const int EM, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + const int CuCount) { + //int M = in_a.size(0); + //int K = in_a.size(1); + //int N = N_in; + wvSpltK_fsdMoe_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + topk_weights.data_ptr(), + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + expert_ids.data_ptr(), + num_tokens_post_padded.data_ptr(), + M, N, K, EM, + num_valid_tokens, + stride_am, stride_ak,stride_be,stride_bk,stride_bn,stride_cm,stride_cn, + m_blck_sz, mul_routed_weight,top_k, + at::cuda::getCurrentCUDAStream(), CuCount); +} + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -103,5 +158,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); m.def("wvSpltK", &wvSpltK); + m.def("wvSpltK_fsdMoe", &wvSpltK_fsdMoe); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index f03d3da5a8f9c..e55e1510ec27f 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1925,6 +1925,1419 
@@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, #endif // defined(__HIP__MI300__) TODO: Add NAVI support + + +#undef M +#undef YTILE +#undef UNRL +#define UNRL 1 +//#define M_BLOCK 4 + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + bool PCML = (K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + //if (kFit < TUC) PCML = false; + + float sum[M_BLOCK][YTILE]; + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + int offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + int off_experts; // add to B[] *K*N loads + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. 
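+  // (descriptive note) i.e. Nrndp rounds N up to the next multiple of YTILE*WvPrGrp, so waves
+  // whose column index n lands past N still run the loop below (to help fill LDS) but skip
+  // the actual compute.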
+ if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = 0; e < num_tokens_post_padded[0]; e+=M_BLOCK) { + kBase = 0; + + for (int m=0; m= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M_BLOCK; m++) { + if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // load only 1 column of weights, despite the moe-gate, made possible by expert list. + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { + if (!token_mask[m]) continue; + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M_BLOCK; m++) { + // skip compute for Ms that are disabled + if (!token_mask[m]) continue; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + for (int y=0; y= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#define mfmaTILEn 16 +#define mfmaTILEk 4 +//#undef WvPrGrp +//#define WvPrGrp 8 +#define USEMFMA +//#define PIPELINED_33334x +//#define PIPELINED_556x +#define PIPELINED4x + +template +__global__ void +__launch_bounds__(WvPrGrp * THRDS) +wvSpltK_fsdMoe_hf_mfma16_( + const DTYPE* __restrict__ A, + const DTYPE* __restrict__ B, + DTYPE* C, + const float* __restrict__ topk_weights, + const int* __restrict__ topk_ids, + const int* __restrict__ sorted_token_ids, + const int* __restrict__ expert_ids, + const int* __restrict__ num_tokens_post_padded, + const int M_in, const int N, const int K, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const bool mul_routed_weight, + const int top_k, + const int CuCount + ) { + +using halfCxT = __attribute__((__vector_size__(mfmaTILEn * A_CHUNK / 2 * sizeof(float)))) float; +using halfC = __attribute__((__vector_size__(A_CHUNK / 2 * sizeof(float)))) float; +using halfT = __attribute__((__vector_size__(mfmaTILEk / 2 * sizeof(float)))) float; + +bool PCML = true;//(K * M_in > 32*1024); + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + int i[A_CHUNK / 2]; + long int l[A_CHUNK / 4]; + halfT hT[A_CHUNK / mfmaTILEk]; + halfC hC; + }; + union bigTypeXt{ + bigType B[mfmaTILEn]; + halfCxT hCT; + }; + + + __shared__ half s[1024 * 32]; + + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + int ETILE = (CuCount * WvPrGrp ) / (N/YTILE); // bump up etile to fill machine + if (ETILE < 1) ETILE = 1; //TODO: what is best default min ETILE? + if (M_in >= 128) ETILE = min(M_in/64, 15); // Heuristic: Add an ETILE for every 64 Ms + + const int num_tblk = num_tokens_post_padded[0] / M_BLOCK; + + // its worth spending time trying to load balance for this num_tokens... 
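+  // (descriptive note) The heuristic below compares three candidate expert-tile counts
+  // (ETILE, ETILE/2, ETILE*2): for each it estimates
+  //   rounds = ceil(N / columns-per-round) * ceil(num_tblk / candidate)
+  // and keeps whichever candidate needs the fewest rounds across the grid.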
+ if ((CuCount/(ETILE*2) > 0) && (num_tblk>0))// TODO: make sure all overflow/inf conditions are avoided + { + int nPrRnd0 = ((CuCount/(ETILE))*WvPrGrp)*YTILE; + int nRnds0 = (N + nPrRnd0 - 1 ) / nPrRnd0; + int tRnds0 = (num_tblk + (ETILE) - 1) / (ETILE); + int rnds0 = nRnds0 * tRnds0; + + int nPrRnd1n = ((CuCount/(ETILE/2))*WvPrGrp)*YTILE; + int nRnds1n = (N + nPrRnd1n - 1 ) / nPrRnd1n; + int tRnds1n = (num_tblk + (ETILE/2) - 1) / (ETILE/2); + int rnds1n = nRnds1n * tRnds1n; + + int nPrRnd1p = ((CuCount/(ETILE*2))*WvPrGrp)*YTILE; + int nRnds1p = (N + nPrRnd1p - 1 ) / nPrRnd1p; + int tRnds1p = (num_tblk + (ETILE*2) - 1) / (ETILE*2); + int rnds1p = nRnds1p * tRnds1p; + + int etl = ETILE; + if (rnds0 > rnds1n) { etl = ETILE/2; rnds0 = rnds1n; } + if (rnds0 > rnds1p) { etl = ETILE*2; rnds0 = rnds1p; } + ETILE = etl; + } + + uint32_t n = ((blockIdx.x/ETILE) * WvPrGrp + threadIdx.y) * YTILE; + +/* if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + }*/ + + if (!PCML) { + for (uint32_t k = 0; k < min(K * M_in, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + if (k_in >= min(K * M_in, 32 * 1024)) break; + + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + } + __syncthreads(); + } + + int YW = (YTILE * WvPrGrp); + int TWC = (THRDS * WvPrGrp * A_CHUNK); + int TUC = (THRDS * UNRL * A_CHUNK); + uint32_t kBase = 0; + //find biggest k size that fits in LDS + uint32_t kFit = (32*1024)/M_BLOCK; + //kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple of TUC + kFit = (kFit%TUC==0) ? kFit : (kFit-kFit%TUC); //round down to multiple of TUC + //if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + +#ifdef USEMFMA + using float4_ = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; + float4_ sum4; +#else + float sum[M_BLOCK][YTILE]; +#endif + + //TRITON + //offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + //offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + //token_mask = offs_token < num_valid_tokens + uint32_t offs_token[M_BLOCK]; + bool token_mask[M_BLOCK]; // add to A[] /top_k*k + uint32_t off_experts; // add to B[] *K*N loads + + int kShfl = A_CHUNK * THRDS * ( threadIdx.y + (threadIdx.x/16)); + int kSprd = A_CHUNK * ( threadIdx.x ); + + uint32_t Nrndp = (N%YW==0) ? N : (N-N%YW+YW); // Note: All waves in the group need to stay alive to the bitter end, just in case they're needed for cooperative loading of next chunk of A[] into LDS. Such Zomby waves are prevented from doing any real work with continues in the loop below. + if (!PCML) Nrndp = N; //unless its not peicmeal + while (n < Nrndp) { + kBase = 0; + for (uint32_t e = (blockIdx.x % ETILE) * M_BLOCK; e < num_tokens_post_padded[0]; e+=M_BLOCK*ETILE) { + kBase = 0; + +#pragma unroll M_BLOCK + for (uint32_t m=0; m= K) break; + if (kOff >= kFit) break; +#ifdef USEMFMA + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * K + kOff; // yes, K should be kFit here. 
but we'lltranspose this below anyway + // Transpose A for MFMAs + uint32_t k_in_x = (k_ot / A_CHUNK) % (K / A_CHUNK); + uint32_t k_in_y = (k_ot / A_CHUNK) / (K / A_CHUNK); + uint32_t k_ot_x = (k_in_x / mfmaTILEn) * mfmaTILEn + (k_in_y % mfmaTILEn); + uint32_t k_ot_y = (k_in_y / mfmaTILEn) * mfmaTILEn + (k_in_x % mfmaTILEn); + + k_ot = (k_ot_y * (kFit / A_CHUNK) + k_ot_x) * A_CHUNK; + + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#else + //int m = threadIdx.x % M_BLOCK; + //for (uint32_t m = 0; m < M_BLOCK; m++) { + //if (!token_mask[m]) continue; + uint32_t k_in = kBase + (offs_token[m]/top_k) * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + //} +#endif + } + __syncthreads(); + } + } + + // kept alive just to participate in A[] loads + if (n >= N) continue; + + int k1 = k1_; + if (shflk) k1 = kBase + (((k1_-kBase) + kShfl) % kFit ); // shfl loads within this lane, to reduce temporal hotspotting + + #define StgMfma4(_LN) { \ + for (uint32_t _t = 0; _t < A_CHUNK/mfmaTILEk; _t++) { \ + sum4 = __builtin_amdgcn_mfma_f32_16x16x16f16( \ + bigB[0][k2].B[_LN].hT[_t], \ + bigA[_LN][k2].hT[_t], \ + sum4, 0, 0, 0); \ + } \ + } + + +#ifdef PIPELINED1x +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/2; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y= K) break; + for (int m = M_BLOCK/2; m < M_BLOCK; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/4; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/4; y= K) break; + for (int m = M_BLOCK/4; m < M_BLOCK/2; m++) + { + bigA[m-M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/2; y<3*YTILE/4; y++) // should 
this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/2].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/2; m < 3*M_BLOCK/4; m++) + { + bigA[m-M_BLOCK/2][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/4; y= K) break; + for (int m = 3*M_BLOCK/4; m < M_BLOCK; m++) + { + bigA[m-3*M_BLOCK/4][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<3; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 3; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 2////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3; y<6; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3; m < 6; m++) + { + bigA[m-3][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } +///////////////////////////ROUND 3////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6; y<9; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6; m < 9; m++) + { + bigA[m-6][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 4////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=9; y<12; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-9].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 9; m < 12; m++) + { + bigA[m-9][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<3; l++) + StgMfma4(l); + } + +///////////////////////////ROUND 5////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=12; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-12].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 12; m < 16; m++) + { + bigA[m-12][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<4; l++) + StgMfma4(l); + } + + + + +#elif defined(PIPELINED_556x) //556x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y<5; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 0; m < 5; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} + +///////////////////////////ROUND 2////////////////////////// +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5; y<10; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/4].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5; m < 10; m++) + { + bigA[m-5][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<5; l++) + StgMfma4(l); + //} +///////////////////////////ROUND 3////////////////////////// + //#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + kSprd; + // if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=10; y<16; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-YTILE/2].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-10].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 10; m < 16; m++) + { + bigA[m-10][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l<6; l++) + StgMfma4(l); + } + +#elif defined(PIPELINED8x) //8x + +///////////////////////////ROUND 1////////////////////////// +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=0; y= K) break; + for (int m = 0; m < M_BLOCK/8; m++) + { + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=YTILE/8; y<2*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = M_BLOCK/8; m < 2*M_BLOCK/8; m++) + { + bigA[m-M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=2*YTILE/8; y<3*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-2*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-2*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 2*M_BLOCK/8; m < 3*M_BLOCK/8; m++) + { + bigA[m-2*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=3*YTILE/8; y<4*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-3*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-3*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 3*M_BLOCK/8; m < 4*M_BLOCK/8; m++) + { + bigA[m-3*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=4*YTILE/8; y<5*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-4*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-4*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 4*M_BLOCK/8; m < 5*M_BLOCK/8; m++) + { + bigA[m-4*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=5*YTILE/8; y<6*YTILE/8; y++) // should this be M_BLOCK? 
+ //bigB[0][k2].B[y-5*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-5*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 5*M_BLOCK/8; m < 6*M_BLOCK/8; m++) + { + bigA[m-5*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=6*YTILE/8; y<7*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-6*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-6*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 6*M_BLOCK/8; m < 7*M_BLOCK/8; m++) + { + bigA[m-6*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; + for (int y=7*YTILE/8; y<8*YTILE/8; y++) // should this be M_BLOCK? + //bigB[0][k2].B[y-7*YTILE/8].hC = (loadnt((halfC*)(&B_[y * K]))); + bigB[0][k2].B[y-7*YTILE/8].hC = *(((halfC*)(&B_[y * K]))); + } +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + kSprd; + if (k_ >= K) break; + for (int m = 7*M_BLOCK/8; m < 8*M_BLOCK/8; m++) + { + bigA[m-7*M_BLOCK/8][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } + //} +//#pragma unroll + //for (uint32_t k2 = 0; k2 < UNRL; k2++) { + // uint32_t k = k1 + k2 * THRDS * A_CHUNK; + // uint32_t k_ = k + threadIdx.x * A_CHUNK; + // if (k_ >= K) break; + for (int l=0; l= K) break; + + const half* B_ = &B[(n + 0) * K + k_ + off_experts*K*N]; +#ifdef USEMFMA + for (int y=0; y= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M_BLOCK; m++) + { +#ifdef USEMFMA +#else + if (!token_mask[m]) continue; +#endif + if (PCML) { + //bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + kFit*m]))); + // skip A[] fetches for Ms that are disabled + bigA[m][k2] = *((const bigType*)(&(s[k_-kBase + m*kFit ]))); + } else { + int aidx = k_ + (offs_token[m]/top_k) * K; + if (aidx + A_CHUNK <= 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[aidx]))); + else + bigA[m][k2] = *((const bigType*)(&(A[aidx]))); + } + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + +#ifdef USEMFMA + bigType stgB; + for (int l=0; l= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + + + +// a = torch.randn((m, k) +// b1 = torch.randn((e, 2 * n, k) +// b2 = torch.randn((e, k, n) +// topk_weights = torch.randn((m, e), device='cuda', dtype=dtype) + +void wvSpltK_fsdMoe_(void* in_a, void* in_b, void* out_c, + void* topk_weights, + void* topk_ids, + void* sorted_token_ids, + 
void* expert_ids, + void* num_tokens_post_padded, + const int M_in, const int N_in, const int K_in, const int E, + const int num_valid_tokens, + const int stride_am, + const int stride_ak, + const int stride_be, + const int stride_bk, + const int stride_bn, + const int stride_cm, + const int stride_cn, + const int m_blck_sz, + const bool mul_routed_weight, + const int top_k, + cudaStream_t stream, const int CuCount) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + auto* a = reinterpret_cast(in_a); + auto* b = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + auto* topk_weights_ = reinterpret_cast(topk_weights); + auto* topk_ids_ = reinterpret_cast(topk_ids); + auto* sorted_token_ids_ = reinterpret_cast(sorted_token_ids); + auto* expert_ids_ = reinterpret_cast(expert_ids); + auto* num_tokens_post_padded_ = reinterpret_cast(num_tokens_post_padded); + switch (m_blck_sz) { + case 1: + wvSpltK_fsdMoe_hf_<1,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 2: + wvSpltK_fsdMoe_hf_<2,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 3: + wvSpltK_fsdMoe_hf_<3,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 4: + wvSpltK_fsdMoe_hf_<4,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 5: + wvSpltK_fsdMoe_hf_<5,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 6: + wvSpltK_fsdMoe_hf_<6,4><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + case 16: + wvSpltK_fsdMoe_hf_mfma16_<16,16><<>>(a, b, c, topk_weights_, topk_ids_, sorted_token_ids_, expert_ids_, num_tokens_post_padded_, M_in, N_in, K_in, E, num_valid_tokens, stride_am, stride_ak, stride_be, stride_bk, stride_bn, stride_cm, stride_cn, mul_routed_weight, top_k, CuCount); + break; + + } +} + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 2356b9ec18b0d..8e3d707b8f149 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -10,7 +10,17 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.models.mixtral import MixtralMoE +from vllm import envs +def permute_weight(x: torch.Tensor) -> 
torch.Tensor: + x_ = x.clone() + x_ = x_.view(x.shape[0], + x.shape[1]//16, 16, + x.shape[2]//32, 4, 8) + x_ = x_.permute(0,1,3,4,2,5) + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]); + return x_ def torch_moe(a, w1, w2, score, topk): B, D = a.shape @@ -52,6 +62,62 @@ def test_fused_moe( torch_output = torch_moe(a, w1, w2, score, topk) assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("n", [14336]) +@pytest.mark.parametrize("k", [4096]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_amd_moe_1( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + if n == k: + pytest.skip() + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + if envs.VLLM_MOE_SHUFFLE: + w1_shuffled = permute_weight(w1.data) + w2_shuffled = permute_weight(w2.data) + + score = torch.randn((m, e), device='cuda', dtype=dtype) + triton_output = fused_moe(a, w1_shuffled, w2_shuffled, score, topk, renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch.allclose(triton_output, torch_output, atol=2e-2, rtol=0) + + +@pytest.mark.parametrize("m", [1, 64, 96, 1000, 237]) +@pytest.mark.parametrize("n", [4096]) +@pytest.mark.parametrize("k", [14336]) +@pytest.mark.parametrize("e", [8]) +@pytest.mark.parametrize("topk", [2]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_amd_moe_2( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + if n == k: + pytest.skip() + a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + if envs.VLLM_MOE_SHUFFLE: + w1_shuffled = permute_weight(w1.data) + w2_shuffled = permute_weight(w2.data) + + score = torch.randn((m, e), device='cuda', dtype=dtype) + triton_output = fused_moe(a, w1_shuffled, w2_shuffled, score, topk, renormalize=False) + torch_output = torch_moe(a, w1, w2, score, topk) + assert torch.allclose(triton_output, torch_output, atol=2e-1, rtol=0) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7c70b1b244f7d..67daabc9b0fd3 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -23,6 +23,9 @@ ARTIFICIAL_PREEMPTION_PROB = 0.5 ARTIFICIAL_PREEMPTION_MAX_CNT = 500 +VLLM_SCHED_PREFILL_COUNT = int( + os.getenv("VLLM_SCHED_PREFILL_COUNT", 0)) # noqa + class PreemptionMode(enum.Enum): """Preemption modes. @@ -263,7 +266,15 @@ def __init__( # simple and NOT fair. It can lead to starvation of some # LoRAs. This should be improved in the future. self.lora_config = lora_config - + self.prefill_timeout = 0 + + # slightly hackey, but if you specify prefill batch count, the delay factor + # needs to exist, otherwise we will always skip. Default will be equal to + # VLLM_SCHED_PREFILL_COUNT, as they should be roughly the same. 
+ # Recommend setting with --scheduler-delay-factor and experimenting + # On command line + if VLLM_SCHED_PREFILL_COUNT > 0 and self.scheduler_config.delay_factor == 0: + self.scheduler_config.delay_factor = VLLM_SCHED_PREFILL_COUNT version = "v1" if self.scheduler_config.use_v2_block_manager: version = "v2" @@ -644,7 +655,8 @@ def _schedule_prefills( waiting_queue = deque([s for s in waiting_queue]) leftover_waiting_sequences: Deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: + + while (VLLM_SCHED_PREFILL_COUNT <= len(waiting_queue) or self._passed_delay(time.time())) and waiting_queue: seq_group = waiting_queue[0] waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) @@ -719,7 +731,6 @@ def _schedule_prefills( waiting_queue.extendleft(leftover_waiting_sequences) if len(seq_groups) > 0: self.prev_prompt = True - return waiting_queue, SchedulerPrefillOutputs( seq_groups=seq_groups, ignored_seq_groups=ignored_seq_groups, diff --git a/vllm/envs.py b/vllm/envs.py index 739a4792ce078..281478283c244 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -43,7 +43,11 @@ VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = True + VLLM_MOE_SHUFFLE: bool = False + FUSED_MOE_PERSISTENT: bool = False + VLLM_MOE_MFMASWIZZLE: bool = True + VLLM_MOE_MFMASWIZZLE_M_THRSHLD: int = 32 # The begin-* and end* here are used by the documentation generator # to extract the used env vars. @@ -246,6 +250,21 @@ # Pad the weight for moe kernel or not "VLLM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))), + + # shuffle the weight for moe kernel or not + "VLLM_MOE_SHUFFLE": + lambda: bool(int(os.getenv("VLLM_MOE_SHUFFLE", "0"))), + + # User persistent version of fused_moe Triton kernel + "FUSED_MOE_PERSISTENT": + lambda: bool(int(os.getenv("FUSED_MOE_PERSISTENT", "0"))), + + # hashem's swizzle + # Swizzle the weights for mfma ops in moe kernel, or not + "VLLM_MOE_MFMASWIZZLE": + lambda: bool(int(os.getenv("VLLM_MOE_MFMASWIZZLE", "1"))), + "VLLM_MOE_MFMASWIZZLE_M_THRSHLD": + lambda: int(os.getenv("VLLM_MOE_MFMASWIZZLE_M_THRSHLD", "32")), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json index 9de6d6a479184..b294a1c08d6f6 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -69,7 +69,7 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 0, "waves_per_eu": 0, @@ -77,11 +77,11 @@ "kpack": 2 }, "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 1, + "num_warps": 8, "num_stages": 0, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, @@ -89,10 +89,10 @@ }, "64": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, @@ -100,7 +100,7 @@ }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, @@ -139,16 +139,16 @@ "num_warps": 4, "num_stages": 
0, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 2 }, "1024": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -159,7 +159,7 @@ "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -168,9 +168,9 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -181,7 +181,7 @@ "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -189,10 +189,10 @@ "4096": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -200,10 +200,10 @@ "16384": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -211,10 +211,10 @@ "18432": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 @@ -222,10 +222,10 @@ "20480": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 0, + "num_stages": 1, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json index f020993a1b615..9e8b4de93747a 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json @@ -128,7 +128,7 @@ "num_warps": 4, "num_stages": 0, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 1 }, "512": { diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index e759d63b588b3..7c009e8745006 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -9,10 +9,12 @@ import triton.language as tl import vllm._moe_C as moe_kernels -from vllm import _custom_ops as ops from vllm import envs +from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm import _custom_C + logger = init_logger(__name__) padding_size = 128 if envs.VLLM_MOE_PADDING else 0 @@ -134,7 +136,7 @@ def fused_moe_kernel( a_ptrs, mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, + other=0.0 ) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, @@ -167,6 +169,182 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % 
args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def fused_moe_persistent_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, + NUM_SMS: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + This is the persistent version of the fused_moe kernel. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. 
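+    The kernel is persistent: it is launched with at most NUM_SMS programs,
+    and each program starts at tile ``start_pid`` and then strides through
+    the tile space by NUM_SMS (tile_id, tile_id + NUM_SMS, ...), so tile
+    scheduling happens in software instead of launching one program per
+    tile. As in the non-persistent kernel, GROUP_SIZE_M groups tiles along
+    the M dimension, which is intended to improve L2 cache reuse.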
+ """ + # ----------------------------------------------------------- + # Simply compute how many iterations each persistent block needs to do + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # num_tiles = num_pid_m * num_pid_n + tile_id = start_pid + + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + # token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) + + # Load tile-invariant runtime constant + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Compute how many tiles are outside the padding region + num_pid_in_group = GROUP_SIZE_M * num_pid_n + pid_m = 0 + tile_id2 = start_pid - NUM_SMS + num_valid_tiles = -1 + while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + num_valid_tiles += 1 + tile_id2 += NUM_SMS + group_id = tile_id2 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) + + for _ in range(0, num_valid_tiles): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + # Compute the mask + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + # Compute the A pointer + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + # Compute the B pointer + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + if EVEN_K: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0 + ) + # We accumulate along the K dimension. + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + # advance tile_id + tile_id += NUM_SMS + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -213,7 +391,7 @@ def moe_align_block_size( device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - expert_ids = torch.empty((max_num_m_blocks, ), + expert_ids = torch.zeros((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), @@ -229,43 +407,30 @@ def moe_align_block_size( ) return sorted_ids, expert_ids, num_tokens_post_pad - -def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], +def invoke_mega_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, top_k: int, - config: Dict[str, Any], compute_type: tl.dtype, + m_blck_sz: int, mul_routed_weight: bool, top_k: int, use_fp8: bool) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 - if not use_fp8: - assert A_scale is None - assert B_scale is None - else: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) - assert B_scale is not None - - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) - - fused_moe_kernel[grid]( + #print("\nm=",A.shape[0],"n=",B.shape[1],"k=",B.shape[2],"e=", B.shape[0], "ml_rt:",mul_routed_weight,"tpk",top_k, "\n") + _custom_C.wvSpltK_fsdMoe(#A, B, C, B.shape[1], 80) A, B, C, - A_scale, - B_scale, topk_weights, + topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, + A.shape[0], B.shape[1], B.shape[2] - padding_size, - sorted_token_ids.shape[0], + B.shape[0], topk_ids.numel(), A.stride(0), A.stride(1), @@ -274,12 +439,100 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, B.stride(1), C.stride(1), C.stride(2), - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8=use_fp8, - **config, - ) + m_blck_sz, + mul_routed_weight, + top_k, + 80) + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if not 
use_fp8: + assert A_scale is None + assert B_scale is None + else: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + + if not envs.FUSED_MOE_PERSISTENT: + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + enable_moe_lds_bypass=True + ) + else: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * 2 + grid = lambda META: (min( + NUM_SMS, + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * + triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) + ), ) + + fused_moe_persistent_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + NUM_SMS=NUM_SMS, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8, + **config, + enable_moe_lds_bypass=True + ) def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @@ -367,8 +620,7 @@ def fused_experts(hidden_states: torch.Tensor, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None): # Check constraints. 
-    assert hidden_states.shape[
-        1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -398,6 +650,7 @@ def fused_experts(hidden_states: torch.Tensor,
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
+            "num_stages": 1,
         }
 
         if M <= E:
@@ -406,8 +659,9 @@ def fused_experts(hidden_states: torch.Tensor,
                 "BLOCK_SIZE_N": 32,
                 "BLOCK_SIZE_K": 64,
                 "GROUP_SIZE_M": 1,
+                "num_stages": 1,
             }
-
+
     intermediate_cache1 = torch.empty(
         (M, topk_ids.shape[1], N),
         device=hidden_states.device,
@@ -424,12 +678,87 @@ def fused_experts(hidden_states: torch.Tensor,
         dtype=hidden_states.dtype,
     )
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
     compute_type = (tl.bfloat16
                     if hidden_states.dtype == torch.bfloat16 else tl.float16)
 
+    #print(hidden_states.shape)
+    #print(intermediate_cache2.shape)
+    #print("M1:", hidden_states.shape[0], "M2:", intermediate_cache2.shape[0])
+    #if hidden_states.shape[0] <= 256 and hidden_states.shape[1] % 8 == 0 and intermediate_cache2.shape[0] <= 256 and not use_fp8 :
+
+    #WVSPLTK_M_THRSHLD = 64
+    #if hidden_states.shape[0] <= WVSPLTK_M_THRSHLD \
+    #        and hidden_states.shape[1] % 8 == 0 \
+    #        and intermediate_cache2.shape[0] <= WVSPLTK_M_THRSHLD \
+    #        and intermediate_cache2.shape[1] % 8 == 0 \
+    #        and not use_fp8 :
+    if envs.VLLM_MOE_MFMASWIZZLE and M <= envs.VLLM_MOE_MFMASWIZZLE_M_THRSHLD:
+        assert compute_type == tl.float16, "Only fp16 is supported for wvSplitK_mfma16x16 for now"
+        #m_blck_sz = -(-(M*topk_ids.shape[1]*3)//E) # target 75% of expert distribution for this M size
+        #if (m_blck_sz >= 12):
+        #    m_blck_sz = 16
+
+        # All calls go to wvSplitK_mfma16x16.
+        m_blck_sz = 16  # TODO: this covers the decode stage; prefill needs its own block size
+        #print("M:", M, " M_BLOCK PICKED:", m_blck_sz)
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            topk_ids, m_blck_sz, E)
+            #topk_ids, config2['BLOCK_SIZE_M'],E)
+        #print("\nsrtd_tkn:", sorted_token_ids)
+        #print("w1Shape:",w1.shape)
+
+        # VLLM_MOE_MFMASWIZZLE applies this swizzle once at weight-load time.
+        w1_ = w1
+        w2_ = w2
+        if not envs.VLLM_MOE_MFMASWIZZLE:  # debug-only fallback: swizzle on the fly
+            if m_blck_sz >= 16:
+                w1_ = torch.clone(w1)
+                w1_ = w1_.view(w1.shape[0], w1.shape[1]//16, 16, w1.shape[2]//128, 16, 8)
+                w1_ = w1_.permute(0, 1, 4, 3, 2, 5)
+                w1_ = w1_.contiguous()
+                w1_ = w1_.view(w1.shape[0], w1.shape[1], w1.shape[2])
+                w2_ = torch.clone(w2)
+                w2_ = w2_.view(w2.shape[0], w2.shape[1]//16, 16, w2.shape[2]//128, 16, 8)
+                w2_ = w2_.permute(0, 1, 4, 3, 2, 5)
+                w2_ = w2_.contiguous()
+                w2_ = w2_.view(w2.shape[0], w2.shape[1], w2.shape[2])
+
+        #print(w1_)
+
+        invoke_mega_fused_moe_kernel(hidden_states,
+                                     w1_,
+                                     intermediate_cache1,
+                                     topk_weights,
+                                     topk_ids,
+                                     sorted_token_ids,
+                                     expert_ids,
+                                     num_tokens_post_padded,
+                                     m_blck_sz,
+                                     False,
+                                     topk_ids.shape[1],
+                                     use_fp8=use_fp8)
+        #print("shdr_invk1:",intermediate_cache1.view(-1, N))
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        #print("shdr_silu:",intermediate_cache2)
+        #print("shdr_silu_shape:", intermediate_cache2.shape)
+        #print("-----------------------------")
+
+        invoke_mega_fused_moe_kernel(intermediate_cache2,
+                                     w2_,
+                                     intermediate_cache3,
+                                     topk_weights,
+                                     topk_ids,
+                                     sorted_token_ids,
+                                     expert_ids,
+                                     num_tokens_post_padded,
+                                     m_blck_sz,
+                                     True,
+                                     1,
+                                     use_fp8=use_fp8)
-    invoke_fused_moe_kernel(hidden_states,
+    else:
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            topk_ids, config['BLOCK_SIZE_M'], E)
+        invoke_fused_moe_kernel(hidden_states,
                             w1,
                             intermediate_cache1,
                             a1_scale,
@@ -445,9 +774,9 @@ def fused_experts(hidden_states: torch.Tensor,
                             compute_type=compute_type,
                             use_fp8=use_fp8)
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2,
+        invoke_fused_moe_kernel(intermediate_cache2,
                             w2,
                             intermediate_cache3,
                             a2_scale,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index ee9db7048f1f6..f974028157a74 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -183,15 +183,17 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
         if not self.use_fp8:
+            w13_ = permute_weight(self.w13_weight.data)
+            w2_ = permute_weight(self.w2_weight.data)
             if envs.VLLM_MOE_PADDING:
-                self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
-                                                     (0, 128), "constant", 0),
-                                               requires_grad=False)
+                w13_ = F.pad(w13_, (0, 128), "constant", 0)
                 torch.cuda.empty_cache()
-                self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
-                                                    (0, 128), "constant", 0),
-                                              requires_grad=False)
+                w2_ = F.pad(w2_, (0, 128), "constant", 0)
                 torch.cuda.empty_cache()
+            self.w13_weight = nn.Parameter(w13_, requires_grad=False)
+            torch.cuda.empty_cache()
+            self.w2_weight = nn.Parameter(w2_, requires_grad=False)
+            torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp16, quantize here.
@@ -603,3 +605,21 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 def all_close_1d(x: torch.Tensor) -> bool:
     assert len(x.shape) == 1
     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    x_ = x
+    if envs.VLLM_MOE_SHUFFLE:
+        x_ = x_.view(x.shape[0],
+                     x.shape[1]//16, 16,
+                     x.shape[2]//32, 4, 8)
+        x_ = x_.permute(0, 1, 3, 4, 2, 5)
+        x_ = x_.contiguous()
+        x_ = x_.view(x.shape[0], x.shape[1], x.shape[2])
+    elif envs.VLLM_MOE_MFMASWIZZLE:  # MFMA swizzle used by the wvSpltK_fsdMoe path
+        x_ = x_.view(x.shape[0],
+                     x.shape[1]//16, 16,
+                     x.shape[2]//128, 16, 8)
+        x_ = x_.permute(0, 1, 4, 3, 2, 5)
+        x_ = x_.contiguous()
+        x_ = x_.view(x.shape[0], x.shape[1], x.shape[2])
+    return x_
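+
+# Layout note (illustrative only, not executed): for a weight tensor x of
+# shape (E, N, K) the VLLM_MOE_MFMASWIZZLE branch above is equivalent to
+#
+#     x.view(E, N // 16, 16, K // 128, 16, 8) \
+#      .permute(0, 1, 4, 3, 2, 5).contiguous().view(E, N, K)
+#
+# i.e. inside every 16 x 128 tile of the (N, K) plane the 16-row axis is
+# swapped with the axis indexing the sixteen 8-element sub-columns, so the
+# shape and dtype stay the same and only the in-memory element order
+# changes, presumably to match the operand order expected by the
+# wvSpltK_fsdMoe mfma16x16 kernel.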