diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 0408bb73c7f2da..803a0003868018 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -168,31 +168,6 @@ def forward(self, x, position_ids):
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
-class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
-        )
-        kwargs["rope_type"] = "linear"
-        super().__init__(*args, **kwargs)
-
-
-class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(self, *args, **kwargs):
-        logger.warning_once(
-            "`LlamaDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
-            "`LlamaRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
-            "__init__)."
-        )
-        kwargs["rope_type"] = "dynamic"
-        super().__init__(*args, **kwargs)
-
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -255,6 +230,126 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+
+def eager_attention_forward(config, query, key, value, mask, **_kwargs):
+    key_states = repeat_kv(key, config.num_key_value_groups)
+    value_states = repeat_kv(value, config.num_key_value_groups)
+
+    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling
+
+    if config.attn_logit_softcapping is not None:
+        attn_weights = attn_weights / config.attn_logit_softcapping
+        attn_weights = torch.tanh(attn_weights)
+        attn_weights = attn_weights * config.attn_logit_softcapping
+    if mask is not None:  # no matter the length, we just slice it
+        causal_mask = mask[:, :, :, : key_states.shape[-2]]
+        attn_weights = attn_weights + causal_mask
+
+    # upcast attention to fp32
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training)
+    attn_output = torch.matmul(attn_weights, value_states)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, attn_weights
+
+
+def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **kwargs):
+    if mask is not None:
+        seq_len = mask.shape[1]
+        query = query[:, :, :seq_len]
+        value = value[:, :, :seq_len]
+
+    query_states = query.transpose(1, 2)
+    key_states = key.transpose(1, 2)
+    value_states = value.transpose(1, 2)
+
+    dropout_rate = config.attention_dropout if config.training else 0.0
+
+    input_dtype = query_states.dtype
+    if input_dtype == torch.float32:
+        query_states = query_states.to(target_dtype)
+        key_states = key_states.to(target_dtype)
+        value_states = value_states.to(target_dtype)
+
+    attn_output = _flash_attention_forward(
+        query_states,
+        key_states,
+        value_states,
+        mask,
+        seq_len,
+        dropout=dropout_rate,
+        softmax_scale=config.scaling,
+        is_causal=config.is_causal,
+        sliding_window=config.sliding_window,
+        use_top_left_mask=config._flash_attn_uses_top_left_mask,
+        **kwargs,
+    )
+
+    return attn_output, None
+
+
+def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs):
+    def tanh_softcap(score, b, h, q_idx, kv_idx):
+        soft_cap = config.attn_logit_softcapping
+        score = soft_cap * torch.tanh(score / soft_cap)
+        if mask is not None:
+            return score + mask[b][0][q_idx][kv_idx]
+        return score
+
+    attn_output = flex_attention(
+        query,
+        key,
+        value,
+        score_mod=tanh_softcap,
+        enable_gqa=True,
+        scale=config.scaling,
+        return_lse=output_attentions,
+    )
+    if not output_attentions:
+        return attn_output, None
+    else:
+        return attn_output[0], attn_output[1]
+
+
+def sdpa_attention_forward(config, query, key, value, mask, **_kwargs):
+    key = repeat_kv(key, config.num_key_value_groups)
+    value = repeat_kv(value, config.num_key_value_groups)
+
+    causal_mask = mask
+    if mask is not None:
+        causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+
+    # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+    # Reference: https://github.com/pytorch/pytorch/issues/112577.
+    if query.device.type == "cuda" and causal_mask is not None:
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+
+    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
+
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=causal_mask,
+        dropout_p=config.attention_dropout if config.training else 0.0,
+        is_causal=is_causal,
+        scale=config.scaling,
+    )
+    return attn_output, None
+
+
+LLAMA_ATTENTION_FUNCTION = {
+    "flash_attention_2": flash_attention_forward,
+    "flex_attention": flex_attention_forward,
+    "eager": eager_attention_forward,
+    "sdpa": sdpa_attention_forward,
+}
+
+
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -284,9 +379,6 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
 
-        # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
-        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -297,7 +389,7 @@ def forward(
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
 
@@ -310,16 +402,7 @@ def forward(
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
+        cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
@@ -327,151 +410,17 @@ def forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
-        attn_output = attn_output.reshape(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-class LlamaFlashAttention2(LlamaAttention):
-    """
-    Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
-        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if isinstance(past_key_value, StaticCache):
-            raise ValueError(
-                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
-                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
-            )
-
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # Flash attention requires the input to have the shape
-        # batch_size x seq_length x head_dim x hidden_dim
-        # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
+        if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]:
+            attention_type = "eager"
         else:
-            cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            attention_type = self.config._attn_implementation
 
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
-        # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        dropout_rate = self.attention_dropout if self.training else 0.0
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (LlamaRMSNorm handles it correctly)
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            position_ids=position_ids,
-            dropout=dropout_rate,
-            sliding_window=getattr(self, "sliding_window", None),
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-            is_causal=self.is_causal,
-            **kwargs,
+        attn_output, attn_weights = LLAMA_ATTENTION_FUNCTION[attention_type](
+            self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions, **kwargs
         )
-        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, -1)
+
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
@@ -480,119 +429,12 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
-class LlamaSdpaAttention(LlamaAttention):
-    """
-    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
-    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
-    SDPA API.
-    """
-
-    # Adapted from LlamaAttention.forward
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
-                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_value,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-                position_embeddings=position_embeddings,
-            )
-
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
-        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
-
-        if position_embeddings is None:
-            logger.warning_once(
-                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-                "removed and `position_embeddings` will be mandatory."
-            )
-            cos, sin = self.rotary_emb(value_states, position_ids)
-        else:
-            cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
-
-        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-        # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
-            query_states = query_states.contiguous()
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
-        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, -1)
-
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-LLAMA_ATTENTION_CLASSES = {
-    "eager": LlamaAttention,
-    "flash_attention_2": LlamaFlashAttention2,
-    "sdpa": LlamaSdpaAttention,
-}
-
-
 class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
+        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
 
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -608,7 +450,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args: