[Python] Update top-k kernel with TIR meta-programming (mlc-ai#3079)
This PR introduces a meta-programmed top-k kernel for
expert selection in MoE. With it, we can deprecate the
previous ad-hoc kernels that were specialized for k=2 and k=4.
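
Note: the generated kernel's per-row selection is an insertion scan over a k-deep register array, which the meta-programming unrolls into a fixed if/elif chain at TIR-construction time. As a rough illustration only (plain Python, not part of this commit; topk_scan_row is a hypothetical name):

def topk_scan_row(row, k):
    # Reference semantics of the unrolled scan: keep the k largest values
    # seen so far in descending order, plus their expert indices.
    vals = [float("-inf")] * k  # mirrors the T.min_value(dtype) init
    idxs = list(range(k))       # mirrors the local_top_k_index init
    for e, v in enumerate(row):
        for i in range(k):
            if v > vals[i]:
                vals[i + 1:] = vals[i:k - 1]  # shift lower-ranked slots down
                idxs[i + 1:] = idxs[i:k - 1]
                vals[i], idxs[i] = v, e
                break
    return vals, idxs

For example, topk_scan_row([0.1, 0.9, 0.3, 0.5], k=2) returns ([0.9, 0.5], [1, 3]).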
MasterJH5574 authored Jan 2, 2025
1 parent 1825fed commit 3463ab7
Showing 1 changed file with 75 additions and 150 deletions.
python/mlc_llm/op/moe_misc.py: 75 additions & 150 deletions
@@ -56,162 +56,87 @@ def gating_softmax_topk(  # pylint: disable=too-many-statements
     index_dtype = "int32"
 
     TX = 1024
-    SCAN_LEN_2 = 2
-    SCAN_LEN_4 = 4

-    # specialized kernel for top 2 case
-    @T.prim_func(private=True)
-    def top2_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_2), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_2), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_2,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_2,), dtype=index_dtype, scope="local")
-        local_top_k_f32 = T.alloc_buffer((SCAN_LEN_2,), dtype="float32", scope="local")
-        local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 2
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("cast"):
-                            vj = T.axis.remap("S", [j])
-                            local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
-                    with T.block("max"):
-                        local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = T.cast(
-                                T.exp(local_top_k_f32[vj] - local_top_k_max[0])
-                                / (
-                                    T.exp(local_top_k_f32[0] - local_top_k_max[0])
-                                    + T.exp(local_top_k_f32[1] - local_top_k_max[0])
-                                ),
-                                dtype,
-                            )
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # specialized kernel for top 4 case
-    @T.prim_func(private=True)
-    def top4_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_4), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_4), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_4,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_4,), dtype=index_dtype, scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k[2] = T.min_value(dtype)
-                        local_top_k[3] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                        local_top_k_index[2] = 2
-                        local_top_k_index[3] = 3
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 4
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                            elif x[vi, vk] > local_top_k[2]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = x[vi, vk]
-                                local_top_k_index[2] = vk
-                            elif x[vi, vk] > local_top_k[3]:
-                                local_top_k[3] = x[vi, vk]
-                                local_top_k_index[3] = vk
-                    for j in T.unroll(SCAN_LEN_4):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = local_top_k[vj]
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # fast path for Mixtral
-    if k == 2 and norm_topk_prob:
+    def _get_topk_softmax_norm_func(k_val: int):
+        def _init_local_top_k(local_top_k, local_top_k_index):
+            for t in range(k_val):
+                T.buffer_store(local_top_k, T.min_value(dtype), indices=[t])
+            for t in range(k_val):
+                T.buffer_store(local_top_k_index, t, indices=[t])
+
+        def _process_value(x, local_top_k, local_top_k_index, vi, vk):
+            if_frames = [T.If(x[vi, vk] > local_top_k[i]) for i in range(k_val)]
+            then_frames = [T.Then() for _ in range(k_val)]
+            else_frames = [T.Else() for _ in range(k_val - 1)]
+            for i in range(k_val):
+                if_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+                with then_frames[i]:
+                    for j in range(k_val - 1, i, -1):
+                        T.buffer_store(local_top_k, local_top_k[j - 1], indices=[j])
+                        T.buffer_store(local_top_k_index, local_top_k_index[j - 1], indices=[j])
+                    T.buffer_store(local_top_k, x[vi, vk], indices=[i])
+                    T.buffer_store(local_top_k_index, vk, indices=[i])
+                if i != k_val - 1:
+                    else_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+
+            for i in range(k_val - 1, -1, -1):
+                if i != k_val - 1:
+                    else_frames[i].__exit__(None, None, None)
+                if_frames[i].__exit__(None, None, None)
+
+        @T.prim_func(private=True)
+        def topk_softmax_norm_func(
+            var_x: T.handle,
+            var_out: T.handle,
+            var_out_index: T.handle,
+        ) -> None:
+            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
+            batch_size = T.int64()
+            x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
+            out = T.match_buffer(var_out, (batch_size, k_val), dtype)
+            out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)
+            local_top_k = T.alloc_buffer((k_val,), dtype=dtype, scope="local")
+            local_top_k_index = T.alloc_buffer((k_val,), dtype=index_dtype, scope="local")
+            for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
+                for ii in T.thread_binding(0, TX, "threadIdx.x"):
+                    with T.block("top_k"):
+                        vi = T.axis.spatial(batch_size, io * TX + ii)
+                        T.where(io * TX + ii < batch_size)
+                        with T.block("init"):
+                            _init_local_top_k(local_top_k, local_top_k_index)
+                        for k in range(num_local_experts):
+                            with T.block("update"):
+                                vk = T.axis.remap("S", [k])
+                                _process_value(x, local_top_k, local_top_k_index, vi, vk)
+                        for j in T.unroll(k_val):
+                            with T.block("output"):
+                                vj = T.axis.remap("S", [j])
+                                out[vi, vj] = local_top_k[vj]
+                                out_index[vi, vj] = local_top_k_index[vj]
+
+        return topk_softmax_norm_func

+    if norm_topk_prob:
         return op.tensor_ir_op(
-            top2_softmax_norm_func,
-            "top2_softmax",
+            _get_topk_softmax_norm_func(k),
+            f"top{k}_softmax",
             args=[x],
             out=(
-                Tensor.placeholder([batch_size, 2], dtype),
-                Tensor.placeholder([batch_size, 2], index_dtype),
-            ),
-        )
-    if k == 4 and not norm_topk_prob:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        return op.tensor_ir_op(
-            top4_softmax_norm_func,
-            "top4_softmax",
-            args=[expert_score],
-            out=(
-                Tensor.placeholder([batch_size, 4], dtype),
-                Tensor.placeholder([batch_size, 4], index_dtype),
+                Tensor.placeholder([batch_size, k], dtype),
+                Tensor.placeholder([batch_size, k], index_dtype),
             ),
         )
-    if norm_topk_prob:
-        # Compute topk first and then softmax to avoid extra re-normalize
-        expert_score, expert_indices = op.topk(
-            x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-        expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype)
-    else:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        expert_score, expert_indices = op.topk(
-            expert_score, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-    return expert_score, expert_indices
+    expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
+    return op.tensor_ir_op(
+        _get_topk_softmax_norm_func(k),
+        f"top{k}_softmax",
+        args=[expert_score],
+        out=(
+            Tensor.placeholder([batch_size, k], dtype),
+            Tensor.placeholder([batch_size, k], index_dtype),
+        ),
+    )


def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor:
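
Note: _process_value enters and exits the T.If/T.Then/T.Else frames by hand so the chain depth can depend on k_val. For k_val = 2 the frames expand to the same update step the deleted specialized kernel wrote out manually; sketched below (a TIR fragment inside the kernel's "update" block, not a standalone program):

# Inside the "update" block, for k_val = 2:
if x[vi, vk] > local_top_k[0]:
    # new maximum: shift slot 0 down, insert at slot 0
    local_top_k[1] = local_top_k[0]
    local_top_k_index[1] = local_top_k_index[0]
    local_top_k[0] = x[vi, vk]
    local_top_k_index[0] = vk
elif x[vi, vk] > local_top_k[1]:
    # second largest: insert at slot 1
    local_top_k[1] = x[vi, vk]
    local_top_k_index[1] = vk

The explicit __enter__/__exit__ calls are what make this shape programmable: a plain Python with-statement would fix the nesting depth at authoring time, whereas entering an else-frame on each loop iteration nests every subsequent comparison one level deeper.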
