Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix graph breaks in Mixtral (#65) #1705

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

"""PyTorch Mixtral model."""

import contextlib
import math
import os
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -74,18 +73,12 @@
print("Not using HPU fused kernel for apply_rotary_pos_emb")
FusedRoPE = None

try:
from habana_frameworks.torch.hpu import sdp_kernel

SDPContext = True
except ImportError:
SDPContext = False

deepspeed_available = is_deepspeed_available()
logger = logging.get_logger(__name__)


def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
if q.device.type == "hpu" and FusedRoPE:
if q.device.type == "hpu" and FusedRoPE is not None:
return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
else:
return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
Expand All @@ -97,7 +90,7 @@ def gaudi_mixtral_rmsnorm_forward(self, hidden_states):
The only differences are:
- override RMSNorm with Habana fused RMSNorm
"""
if hidden_states.device.type == "hpu" and FusedRMSNorm:
if hidden_states.device.type == "hpu" and FusedRMSNorm is not None:
# mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
if hidden_states.dtype != self.weight.dtype:
orig_dtype = hidden_states.dtype
Expand Down Expand Up @@ -305,7 +298,7 @@ def forward(
else:
past_key_value = None

if FusedSDPA:
if FusedSDPA is not None:
if query_states.dtype != key_states.dtype:
key_states = key_states.type(query_states.dtype)
value_states = value_states.type(query_states.dtype)
Expand All @@ -322,12 +315,17 @@ def forward(
)
htcore.mark_step()
else:
with (
sdp_kernel(enable_recompute=flash_attention_recompute) if SDPContext else contextlib.nullcontext()
):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
attn_output = FusedSDPA.apply(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
"None",
flash_attention_recompute,
)
else:
query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv(
query_states, key_states, value_states, attention_mask, self.num_key_value_groups
Expand All @@ -351,7 +349,7 @@ def forward(

attn_output = self.o_proj(attn_output)

if not output_attentions or FusedSDPA:
if not output_attentions or FusedSDPA is not None:
attn_weights = None

return attn_output, attn_weights, past_key_value
Expand All @@ -376,7 +374,7 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) ->
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
if deepspeed_available and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
Expand Down Expand Up @@ -424,7 +422,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

if is_deepspeed_available() and (not self.training):
if deepspeed_available and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
Expand All @@ -450,7 +448,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
experts_min=0,
experts_max=7,
)
if is_deepspeed_available() and (not self.training):
if deepspeed_available and (not self.training):
from deepspeed import comm as dist

if dist.is_initialized():
Expand Down