From 1f348b85459be2b12f9e86be95ef5a7179f641cf Mon Sep 17 00:00:00 2001 From: Jan Kaniecki Date: Mon, 12 Aug 2024 14:54:04 +0200 Subject: [PATCH] Reimplement silu_and_mul for mixtral (#167) * Reimplement silu and mul in mixtral * Typo fix --- vllm/hpu/ops.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index bd737917cb919..3748eb3544dd1 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -16,13 +16,6 @@ PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1') -def silu_and_mul(output, input): - d = input.shape[-1] // 2 - silu = torch.nn.SiLU().to(input.device) - x, y = torch.split(input, d, dim=-1) - output.copy_(silu(x) * y) - - def fetch_from_cache(cache, blocks, permutations): return [ cache.index_select(0, blocks[:, i]).permute(permutations) @@ -81,12 +74,9 @@ def paged_attention_v1(query, return attn_weights.squeeze(-2) -def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor: +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - output_shape = (x.shape[:-1] + (d, )) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - silu_and_mul(out, x) - return out + return F.silu(x[..., :d]) * x[..., d:] def static_fused_moe(hidden_states, w1, w2, score, topk): @@ -111,13 +101,10 @@ def static_fused_moe(hidden_states, w1, w2, score, topk): htorch.core.mark_step() for expert_idx in range(num_experts): - padded_weight = padded_weights[expert_idx] - current_state_static = hidden_states.reshape(-1, D) - w_output = silu_and_mul_wrapper( - torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1))) + w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1)) + w_output = silu_and_mul(w_output) w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1)) - current_hidden_states_static = w_output * padded_weight - final_hidden_states += current_hidden_states_static + final_hidden_states += w_output * padded_weights[expert_idx] htorch.core.mark_step() return final_hidden_states.view(-1, D)