diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
index c11d7a277a..3ce3858c80 100644
--- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
+++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -20,7 +20,6 @@
 """PyTorch Mixtral model."""
 
-import contextlib
 import math
 import os
 from typing import List, Optional, Tuple, Union
@@ -74,18 +73,12 @@
     print("Not using HPU fused kernel for apply_rotary_pos_emb")
     FusedRoPE = None
 
-try:
-    from habana_frameworks.torch.hpu import sdp_kernel
-
-    SDPContext = True
-except ImportError:
-    SDPContext = False
-
+deepspeed_available = is_deepspeed_available()
 
 logger = logging.get_logger(__name__)
 
 
 def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
+    if q.device.type == "hpu" and FusedRoPE is not None:
         return apply_customized_rope_module(q, k, cos, sin, position_ids, training)
     else:
         return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
@@ -97,7 +90,7 @@ def gaudi_mixtral_rmsnorm_forward(self, hidden_states):
     The only differences are:
         - override RMSNorm with Habana fused RMSNorm
     """
-    if hidden_states.device.type == "hpu" and FusedRMSNorm:
+    if hidden_states.device.type == "hpu" and FusedRMSNorm is not None:
         # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
         if hidden_states.dtype != self.weight.dtype:
             orig_dtype = hidden_states.dtype
@@ -305,7 +298,7 @@ def forward(
         else:
             past_key_value = None
 
-        if FusedSDPA:
+        if FusedSDPA is not None:
             if query_states.dtype != key_states.dtype:
                 key_states = key_states.type(query_states.dtype)
                 value_states = value_states.type(query_states.dtype)
@@ -322,12 +315,17 @@ def forward(
                 )
                 htcore.mark_step()
             else:
-                with (
-                    sdp_kernel(enable_recompute=flash_attention_recompute) if SDPContext else contextlib.nullcontext()
-                ):
-                    attn_output = FusedSDPA.apply(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None
-                    )
+                attn_output = FusedSDPA.apply(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask,
+                    0.0,
+                    False,
+                    None,
+                    "None",
+                    flash_attention_recompute,
+                )
         else:
             query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv(
                 query_states, key_states, value_states, attention_mask, self.num_key_value_groups
@@ -351,7 +349,7 @@ def forward(
 
         attn_output = self.o_proj(attn_output)
 
-        if not output_attentions or FusedSDPA:
+        if not output_attentions or FusedSDPA is not None:
             attn_weights = None
 
         return attn_output, attn_weights, past_key_value
@@ -376,7 +374,7 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) ->
     # router_logits: (batch * sequence_length, n_experts)
     router_logits = self.gate(hidden_states)
 
-    if is_deepspeed_available() and (not self.training):
+    if deepspeed_available and (not self.training):
         from deepspeed import comm as dist
 
         if dist.is_initialized():
@@ -424,7 +422,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
     # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)
 
-    if is_deepspeed_available() and (not self.training):
+    if deepspeed_available and (not self.training):
         from deepspeed import comm as dist
 
         if dist.is_initialized():
@@ -450,7 +448,7 @@ def gaudi_mixtral_block_dynamic_moe_forward(self, hidden_states: torch.Tensor) -
         experts_min=0,
         experts_max=7,
     )
-    if is_deepspeed_available() and (not self.training):
+    if deepspeed_available and (not self.training):
         from deepspeed import comm as dist
 
         if dist.is_initialized():
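
Note on the attention change: the `sdp_kernel(enable_recompute=...)` context manager (and the `SDPContext`/`contextlib` fallback) is removed, and the recompute preference is instead passed directly into `FusedSDPA.apply`. Below is a minimal sketch of the new call shape. The meanings of the two trailing positional arguments ("None" as a softmax mode string and the recompute toggle) are assumptions read off the call site in the diff, not taken from documented kernel parameter names; verify against the installed habana_frameworks release.

# Sketch only: mirrors the updated call site in the Mixtral attention forward.
from habana_frameworks.torch.hpex.kernels import FusedSDPA

def fused_attention(query_states, key_states, value_states, attention_mask,
                    flash_attention_recompute: bool):
    return FusedSDPA.apply(
        query_states,
        key_states,
        value_states,
        attention_mask,
        0.0,                        # dropout probability
        False,                      # is_causal: the mask is supplied explicitly
        None,                       # scale: fall back to the kernel default
        "None",                     # assumed softmax-mode selector
        flash_attention_recompute,  # assumed recompute toggle, replacing sdp_kernel(enable_recompute=...)
    )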