diff --git a/python/mlc_llm/op/moe_misc.py b/python/mlc_llm/op/moe_misc.py
index 2c834595fc..a897ecba8b 100644
--- a/python/mlc_llm/op/moe_misc.py
+++ b/python/mlc_llm/op/moe_misc.py
@@ -56,162 +56,110 @@ def gating_softmax_topk(  # pylint: disable=too-many-statements
     index_dtype = "int32"
 
     TX = 1024
-    SCAN_LEN_2 = 2
-    SCAN_LEN_4 = 4
 
-    # specialized kernel for top 2 case
-    @T.prim_func(private=True)
-    def top2_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_2), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_2), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_2,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_2,), dtype=index_dtype, scope="local")
-        local_top_k_f32 = T.alloc_buffer((SCAN_LEN_2,), dtype="float32", scope="local")
-        local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 2
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("cast"):
-                            vj = T.axis.remap("S", [j])
-                            local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
-                    with T.block("max"):
-                        local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = T.cast(
-                                T.exp(local_top_k_f32[vj] - local_top_k_max[0])
-                                / (
-                                    T.exp(local_top_k_f32[0] - local_top_k_max[0])
-                                    + T.exp(local_top_k_f32[1] - local_top_k_max[0])
-                                ),
-                                dtype,
-                            )
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # specialized kernel for top 4 case
-    @T.prim_func(private=True)
-    def top4_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_4), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_4), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_4,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_4,), dtype=index_dtype, scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k[2] = T.min_value(dtype)
-                        local_top_k[3] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                        local_top_k_index[2] = 2
-                        local_top_k_index[3] = 3
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 4
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                            elif x[vi, vk] > local_top_k[2]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = x[vi, vk]
-                                local_top_k_index[2] = vk
-                            elif x[vi, vk] > local_top_k[3]:
-                                local_top_k[3] = x[vi, vk]
-                                local_top_k_index[3] = vk
-                    for j in T.unroll(SCAN_LEN_4):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = local_top_k[vj]
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # fast path for Mixtral
-    if k == 2 and norm_topk_prob:
-        return op.tensor_ir_op(
-            top2_softmax_norm_func,
-            "top2_softmax",
-            args=[x],
-            out=(
-                Tensor.placeholder([batch_size, 2], dtype),
-                Tensor.placeholder([batch_size, 2], index_dtype),
-            ),
-        )
-    if k == 4 and not norm_topk_prob:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        return op.tensor_ir_op(
-            top4_softmax_norm_func,
-            "top4_softmax",
-            args=[expert_score],
-            out=(
-                Tensor.placeholder([batch_size, 4], dtype),
-                Tensor.placeholder([batch_size, 4], index_dtype),
-            ),
-        )
-    if norm_topk_prob:
-        # Compute topk first and then softmax to avoid extra re-normalize
-        expert_score, expert_indices = op.topk(
-            x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-        expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype)
-    else:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        expert_score, expert_indices = op.topk(
-            expert_score, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-    return expert_score, expert_indices
+    def _get_topk_softmax_norm_func(k_val: int):
+        """Build a TIR kernel that picks the ``k_val`` largest entries of each row.
+
+        The kernel scans the ``num_local_experts`` columns of every row and writes
+        the ``k_val`` largest values, largest first, together with their column
+        indices.  Despite its name it applies no softmax; the caller does that
+        either before or after the selection.
+        """
+
+        def _init_local_top_k(local_top_k, local_top_k_index):
+            """Emit stores that reset the running top-k values and indices."""
+            for t in range(k_val):
+                T.buffer_store(local_top_k, T.min_value(dtype), indices=[t])
+            for t in range(k_val):
+                T.buffer_store(local_top_k_index, t, indices=[t])
+
+        def _process_value(x, local_top_k, local_top_k_index, vi, vk):
+            """Emit the if/else chain that inserts ``x[vi, vk]`` into the running top-k.
+
+            Generates the equivalent of ``if v > top[0]: ... elif v > top[1]: ...``.
+            The frames are entered and exited by hand because the chain length
+            depends on ``k_val``.
+            """
+            if_frames = [T.If(x[vi, vk] > local_top_k[i]) for i in range(k_val)]
+            then_frames = [T.Then() for _ in range(k_val)]
+            else_frames = [T.Else() for _ in range(k_val - 1)]
+            for i in range(k_val):
+                if_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+                with then_frames[i]:
+                    # Shift the lower-ranked slots down by one, then store the new value.
+                    for j in range(k_val - 1, i, -1):
+                        T.buffer_store(local_top_k, local_top_k[j - 1], indices=[j])
+                        T.buffer_store(local_top_k_index, local_top_k_index[j - 1], indices=[j])
+                    T.buffer_store(local_top_k, x[vi, vk], indices=[i])
+                    T.buffer_store(local_top_k_index, vk, indices=[i])
+                if i != k_val - 1:
+                    else_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+
+            # Close the nested frames innermost first.
+            for i in range(k_val - 1, -1, -1):
+                if i != k_val - 1:
+                    else_frames[i].__exit__(None, None, None)
+                if_frames[i].__exit__(None, None, None)
+
+        @T.prim_func(private=True)
+        def topk_softmax_norm_func(
+            var_x: T.handle,
+            var_out: T.handle,
+            var_out_index: T.handle,
+        ) -> None:
+            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
+            batch_size = T.int64()
+            x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
+            out = T.match_buffer(var_out, (batch_size, k_val), dtype)
+            out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)
+            local_top_k = T.alloc_buffer((k_val,), dtype=dtype, scope="local")
+            local_top_k_index = T.alloc_buffer((k_val,), dtype=index_dtype, scope="local")
+            # One thread per row of x.
+            for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
+                for ii in T.thread_binding(0, TX, "threadIdx.x"):
+                    with T.block("top_k"):
+                        vi = T.axis.spatial(batch_size, io * TX + ii)
+                        T.where(io * TX + ii < batch_size)
+                        with T.block("init"):
+                            _init_local_top_k(local_top_k, local_top_k_index)
+                        for k in range(num_local_experts):
+                            with T.block("update"):
+                                vk = T.axis.remap("S", [k])
+                                _process_value(x, local_top_k, local_top_k_index, vi, vk)
+                        for j in T.unroll(k_val):
+                            with T.block("output"):
+                                vj = T.axis.remap("S", [j])
+                                out[vi, vj] = local_top_k[vj]
+                                out_index[vi, vj] = local_top_k_index[vj]
+
+        return topk_softmax_norm_func
+
+    if norm_topk_prob:
+        # Select the top-k logits first and softmax only those k values, so the
+        # returned scores are already normalized over the chosen experts.
+        expert_score, expert_indices = op.tensor_ir_op(
+            _get_topk_softmax_norm_func(k),
+            f"top{k}_softmax",
+            args=[x],
+            out=(
+                Tensor.placeholder([batch_size, k], dtype),
+                Tensor.placeholder([batch_size, k], index_dtype),
+            ),
+        )
+        expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype)
+        return expert_score, expert_indices
+
+    # Otherwise softmax over all experts first, then keep the top-k probabilities.
+    expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
+    return op.tensor_ir_op(
+        _get_topk_softmax_norm_func(k),
+        f"top{k}_softmax",
+        args=[expert_score],
+        out=(
+            Tensor.placeholder([batch_size, k], dtype),
+            Tensor.placeholder([batch_size, k], index_dtype),
+        ),
+    )
 
 
 def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor: