Skip to content

Commit

Permalink
Reimplement silu_and_mul for mixtral (opendatahub-io#167)
Browse files Browse the repository at this point in the history
* Reimplement silu and mul in mixtral

* Typo fix
  • Loading branch information
jkaniecki authored Aug 12, 2024
1 parent 37ca17f commit 1f348b8
Showing 1 changed file with 5 additions and 18 deletions.
23 changes: 5 additions & 18 deletions vllm/hpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@
PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')


def silu_and_mul(output, input):
    """Compute SiLU-and-multiply into a preallocated output tensor.

    Splits ``input`` in half along its last dimension into a gate part and
    a linear part, then writes ``silu(gate) * linear`` into ``output``
    in place via ``copy_``. Returns None (in-place convention).
    """
    half = input.shape[-1] // 2
    gate, up = torch.split(input, half, dim=-1)
    # Functional form avoids constructing a SiLU module per call.
    output.copy_(torch.nn.functional.silu(gate) * up)


def fetch_from_cache(cache, blocks, permutations):
return [
cache.index_select(0, blocks[:, i]).permute(permutations)
Expand Down Expand Up @@ -81,12 +74,9 @@ def paged_attention_v1(query,
return attn_weights.squeeze(-2)


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Return ``silu(gate) * up`` where the last dim of ``x`` is (gate, up).

    The input's final dimension is split at its midpoint; the first half is
    passed through SiLU and multiplied elementwise by the second half. The
    result has the same shape as ``x`` except the last dim is halved.
    """
    half = x.shape[-1] // 2
    gate = x[..., :half]
    up = x[..., half:]
    return F.silu(gate) * up


def static_fused_moe(hidden_states, w1, w2, score, topk):
Expand All @@ -111,13 +101,10 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
htorch.core.mark_step()

for expert_idx in range(num_experts):
padded_weight = padded_weights[expert_idx]
current_state_static = hidden_states.reshape(-1, D)
w_output = silu_and_mul_wrapper(
torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
w_output = silu_and_mul(w_output)
w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
current_hidden_states_static = w_output * padded_weight
final_hidden_states += current_hidden_states_static
final_hidden_states += w_output * padded_weights[expert_idx]
htorch.core.mark_step()

return final_hidden_states.view(-1, D)
Expand Down

0 comments on commit 1f348b8

Please sign in to comment.