From 79cb53c1d078d333977b51b211becb861f41f9ed Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 28 Nov 2024 08:01:54 +0100 Subject: [PATCH 01/97] refactor LlamaAttention --- .../models/llama/modeling_llama.py | 367 +++++++----------- 1 file changed, 132 insertions(+), 235 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8e06098b04c6..803a00038680 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -230,6 +230,126 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward(config, query, key, value, mask, **_kwargs): + key_states = repeat_kv(key, config.num_key_value_groups) + value_states = repeat_kv(value, config.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + + if config.attn_logit_softcapping is not None: + attn_weights = attn_weights / config.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * config.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **kwargs): + if mask is not None: + seq_len = mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + + query_states = query.transpose(1, 2) + key_states = key.transpose(1, 2) + value_states = value.transpose(1, 2) + + dropout_rate = config.attention_dropout if config.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + seq_len, + dropout=dropout_rate, + softmax_scale=config.scaling, + is_causal=config.is_causal, + sliding_window=config.sliding_window, + use_top_left_mask=config._flash_attn_uses_top_left_mask, + **kwargs + ) + + return attn_output, None + + +def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): + def tanh_softcap(score, b, h, q_idx, kv_idx): + soft_cap = config.attn_logit_softcapping + score = soft_cap * torch.tanh(score / soft_cap) + if mask is not None: + return score + mask[b][0][q_idx][kv_idx] + return score + + attn_output = flex_attention( + query, + key, + value, + score_mod=tanh_softcap, + enable_gqa=True, + scale=config.scaling, + return_lse=output_attentions, + ) + if not output_attentions: + return attn_output, None + else: + return attn_output[0], attn_output[1] + + +def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): + key = repeat_kv(key, config.num_key_value_groups) + value = repeat_kv(value, config.num_key_value_groups) + + causal_mask = mask + if mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query.device.type == "cuda" and causal_mask is not None: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=config.attention_dropout if config.training else 0.0, + is_causal=is_causal, + scale=config.scaling, + ) + return attn_output, None + + +LLAMA_ATTENTION_FUNCTION = { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "eager": eager_attention_forward, + "sdpa": sdpa_attention_forward, +} + + class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -268,8 +388,8 @@ def forward( output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -290,26 +410,14 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: + attention_type = "eager" + else: + attention_type = self.config._attn_implementation - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = LLAMA_ATTENTION_FUNCTION[attention_type]( + self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions, **kwargs + ) attn_output = attn_output.reshape(bsz, q_len, -1) @@ -321,223 +429,12 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaFlashAttention2(LlamaAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - **kwargs, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaSdpaAttention(LlamaAttention): - """ - Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from LlamaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -LLAMA_ATTENTION_CLASSES = { - "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, - "sdpa": LlamaSdpaAttention, -} - - class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -552,8 +449,8 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: From 4bb485b45176083c5b77f573089476141b4a4f83 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 14:48:44 +0100 Subject: [PATCH 02/97] minimal changes --- .../integrations/flash_attention.py | 36 ++ .../integrations/flex_attention.py | 28 ++ .../integrations/sdpa_attention.py | 39 +++ src/transformers/modeling_utils.py | 246 ++++++++----- .../models/llama/modeling_llama.py | 330 ++++-------------- .../models/mistral/modular_mistral.py | 182 ++++++++++ src/transformers/models/olmo/modular_olmo.py | 117 +++++++ src/transformers/models/phi/modular_phi.py | 168 +++++++++ 8 files changed, 795 insertions(+), 351 deletions(-) create mode 100644 src/transformers/integrations/flash_attention.py create mode 100644 src/transformers/integrations/flex_attention.py create mode 100644 src/transformers/integrations/sdpa_attention.py create mode 100644 src/transformers/models/mistral/modular_mistral.py create mode 100644 src/transformers/models/olmo/modular_olmo.py create mode 100644 src/transformers/models/phi/modular_phi.py diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py new file mode 100644 index 000000000000..2189baef8158 --- /dev/null +++ b/src/transformers/integrations/flash_attention.py @@ -0,0 +1,36 @@ +import torch + +from ..modeling_flash_attention_utils import _flash_attention_forward + + +def flash_attention_forward( + config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs +): + if attentions_mask is not None: + seq_len = attentions_mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + else: + seq_len = query.shape[1] + + dropout_rate = config.attention_dropout if training else 0.0 + + input_dtype = query.dtype + if input_dtype == torch.float32: + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_forward( + query, + key, + value, + attentions_mask, + seq_len, + config=config, + dropout=dropout_rate, + layer_idx=layer_idx, + **kwargs, + ) + + return attn_output, None diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py new file mode 100644 index 000000000000..9c309a9ad505 --- /dev/null +++ b/src/transformers/integrations/flex_attention.py @@ -0,0 +1,28 @@ +from ..utils import is_torch_greater_or_equal + + +if is_torch_greater_or_equal("2.5"): + from torch.nn.attention.flex_attention import flex_attention + + +def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + def causal_mod(score, b, h, q_idx, kv_idx): + if causal_mask is not None: + score += causal_mask[b][0][q_idx][kv_idx] + return score + + attn_output, attention_weights = flex_attention( + query, + key, + value, + score_mod=causal_mod, + enable_gqa=True, + scale=module.scaling, + return_lse=True, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attention_weights diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py new file mode 100644 index 000000000000..0cf58f035ea9 --- /dev/null +++ b/src/transformers/integrations/sdpa_attention.py @@ -0,0 +1,39 @@ +import torch + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + is_causal = True if causal_mask is None and query.shape[1] > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=module.config.attention_dropout if module.training else 0.0, + is_causal=is_causal, + scale=module.scaling, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 22dd1b7ccea5..e5bb3fb1c8a4 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -45,6 +45,9 @@ from .dynamic_module_utils import custom_object_save from .generation import CompileConfig, GenerationConfig, GenerationMixin from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from .integrations.flash_attention import flash_attention_forward +from .integrations.flex_attention import flex_attention_forward +from .integrations.sdpa_attention import sdpa_attention_forward from .loss.loss_utils import LOSS_MAPPING from .pytorch_utils import ( # noqa: F401 Conv1D, @@ -1487,7 +1490,10 @@ def _autoset_attn_implementation( message += ( ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' ) - raise ValueError(message + ".") + if config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + pass + else: + raise ValueError(message + ".") # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. requested_attn_implementation = config._attn_implementation_internal @@ -1525,10 +1531,10 @@ def _autoset_attn_implementation( config = cls._check_and_enable_flex_attn(config, hard_check_only=True) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. - config = cls._check_and_enable_sdpa( - config, - hard_check_only=False if requested_attn_implementation is None else True, - ) + # config = cls._check_and_enable_sdpa( + # config, + # hard_check_only=False if requested_attn_implementation is None else True, + # ) if ( torch.version.hip is not None @@ -1539,6 +1545,8 @@ def _autoset_attn_implementation( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." ) torch.backends.cuda.enable_flash_sdp(False) + elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + pass elif isinstance(requested_attn_implementation, dict): config._attn_implementation = None else: @@ -2509,91 +2517,16 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - """ - Activates gradient checkpointing for the current model. + for layer in list(self.modules()): + if isinstance(layer, GradientCheckpointLayer): + layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs) - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - - We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of - the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - - Args: - gradient_checkpointing_kwargs (dict, *optional*): - Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. - """ - if not self.supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {"use_reentrant": True} - - gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) - - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` method - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) - else: - self.apply(partial(self._set_gradient_checkpointing, value=True)) - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - - if getattr(self, "_hf_peft_config_loaded", False): - # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True - # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 - # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate - # the gradients to make sure the gradient flows. - self.enable_input_require_grads() - - def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): - is_gradient_checkpointing_set = False - - # Apply it on the top-level module in case the top-level modules supports it - # for example, LongT5Stack inherits from `PreTrainedModel`. - if hasattr(self, "gradient_checkpointing"): - self._gradient_checkpointing_func = gradient_checkpointing_func - self.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - for module in self.modules(): - if hasattr(module, "gradient_checkpointing"): - module._gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - if not is_gradient_checkpointing_set: - raise ValueError( - f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" - " `gradient_checkpointing` to modules of the model that uses checkpointing." - ) - - def gradient_checkpointing_disable(self): - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if self.supports_gradient_checkpointing: - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` methid - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=False) - else: - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - self.apply(partial(self._set_gradient_checkpointing, value=False)) - - if getattr(self, "_hf_peft_config_loaded", False): - self.disable_input_require_grads() + @property + def gradient_checkpointing(self, gradient_checkpointing_kwargs=None): + for layer in list(self.modules()): + if isinstance(layer, GradientCheckpointLayer): + return layer.gradient_checkpointing + return False @property def is_gradient_checkpointing(self) -> bool: @@ -5631,3 +5564,138 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): files_content[filename].append(device_map[weight_name]) return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] + + +ALL_ATTENTION_FUNCTIONS: Dict[str, Dict[str, Callable]] = {} + +ALL_ATTENTION_FUNCTIONS.update( + { + "flash_attention_2": flash_attention_forward, + "flex_attention": flex_attention_forward, + "sdpa": sdpa_attention_forward, + } +) + + +class GradientCheckpointLayer(torch.nn.Module): + def __init__(self, *args, **kwargs): + self.gradient_checkpointing = False + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs): + """ + Adjust the behavior of the inherited class by overriding `__call__`. + + Automatically handles gradient checkpointing based on flags in the provided arguments. + """ + # Extract necessary flags and arguments + gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr( + self, "gradient_checkpointing", False + ) + training = self.training + + if gradient_checkpointing and training: + # Use gradient checkpointing + return self._apply_gradient_checkpointing(*args, **kwargs) + else: + # Default behavior: call the original `forward` method + return self.forward(*args, **kwargs) + + def _apply_gradient_checkpointing(self, *args, **kwargs): + """ + Apply gradient checkpointing using the appropriate function. + + By default, uses `torch.utils.checkpoint.checkpoint`. + """ + + # Assume `self.forward` is compatible with checkpointing + def wrapped_forward(): + return self.forward(*args, **kwargs) + + return self._gradient_checkpointing_func(wrapped_forward) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + self.gradient_checkpointing = True + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 803a00038680..59162b4ec15b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -49,6 +49,7 @@ logging, replace_return_docstrings, ) +from ...utils.generic import validate_config_kwargs from .configuration_llama import LlamaConfig @@ -93,31 +94,14 @@ def __init__( config: Optional[LlamaConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`LlamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -230,177 +214,61 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) -def eager_attention_forward(config, query, key, value, mask, **_kwargs): - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling - - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -def flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **kwargs): - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - **kwargs - ) - - return attn_output, None - - -def flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - return attn_output, None - else: - return attn_output[0], attn_output[1] - - -def sdpa_attention_forward(config, query, key, value, mask, **_kwargs): - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - return attn_output, None - - -LLAMA_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - class LlamaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -410,26 +278,24 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - attention_type = "eager" - else: - attention_type = self.config._attn_implementation + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output, attn_weights = LLAMA_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions, **kwargs + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - -class LlamaDecoderLayer(nn.Module): +class LlamaDecoderLayer(GradientCheckpointLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -452,28 +318,6 @@ def forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -653,10 +497,6 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") - # Initialize weights and apply final processing self.post_init() @@ -700,31 +540,19 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -733,42 +561,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -778,18 +589,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py new file mode 100644 index 000000000000..f60badcce8d3 --- /dev/null +++ b/src/transformers/models/mistral/modular_mistral.py @@ -0,0 +1,182 @@ +import torch +import torch.utils.checkpoint + +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ..llama.modeling_llama import ( + LlamaForMultipleChoice, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, +) +from .configuration_mistral import MistralConfig + + +class MistralModel(LlamaModel): + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_cache: bool, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: MistralConfig, + past_key_values: Cache, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + config (`MistralConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class MistralForTokenClassification(LlamaForTokenClassification): + pass + + +class MistralForSequenceClassification(LlamaForSequenceClassification): + pass + + +class MistralForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +class MistralForMultipleChoice(LlamaForMultipleChoice): + pass diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py new file mode 100644 index 000000000000..fc548766eb92 --- /dev/null +++ b/src/transformers/models/olmo/modular_olmo.py @@ -0,0 +1,117 @@ +import torch +import torch.utils.checkpoint + +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ..llama.modeling_llama import ( + LlamaForMultipleChoice, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaMLP, + LlamaAttention,apply_rotary_pos_emb, eager_attention_forward +) +from .configuration_olmo import OlmoConfig +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple, Callable +class OlmoLayerNorm(nn.Module): + """LayerNorm but with no learnable weight or bias.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.normalized_shape = (hidden_size,) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to( + orig_dtype + ) + + +class OlmoRMSNorm(LlamaRMSNorm): + pass + + +class OlmoMLP(LlamaMLP): + pass + +class OlmoAttention(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class OlmoDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: OlmoConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + self.mlp = OlmoMLP(config) + self.input_layernorm = OlmoLayerNorm(config.hidden_size) + self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + + +class OlmoForTokenClassification(LlamaForTokenClassification): + pass + + +class OlmoForSequenceClassification(LlamaForSequenceClassification): + pass + + +class OlmoForQuestionAnswering(LlamaForQuestionAnswering): + pass + + +class OlmoForMultipleChoice(LlamaForMultipleChoice): + pass diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py new file mode 100644 index 000000000000..189c69397602 --- /dev/null +++ b/src/transformers/models/phi/modular_phi.py @@ -0,0 +1,168 @@ +from .configuration_phi import PhiConfig +from ..llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb, LlamaMLP, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaForQuestionAnswering +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +import torch.nn as nn + +import math +import torch +from typing import Optional, Tuple, Callable +from transformers.cache_utils import Cache + + +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, layer_head_mask=None, **_kwargs): + key_states = repeat_kv(key_states, attention_class.num_key_value_groups) + value_states = repeat_kv(value_states, attention_class.num_key_value_groups) + + # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow + attn_weights = torch.matmul( + query.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) + ) / math.sqrt(attention_class.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=attention_class.attention_dropout, training=attention_class.training) + + attn_output = torch.matmul(attn_weights, value_states) + return attn_output, attn_weights + + +class PhiAttention(LlamaAttention): + + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.dense = self.o_proj + self.qk_layernorm = config.qk_layernorm + if self.qk_layernorm: + self.q_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + self.k_layernorm = nn.LayerNorm( + config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + ) + + def forward(self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + cos, sin = position_embeddings + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + +class PhiMLP(LlamaMLP): + pass +class PhiDecoderLayer(nn.Module): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__() + self.self_attn = PhiAttention(config, layer_idx=layer_idx) + self.mlp = PhiMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + attn_outputs = self.resid_dropout(attn_outputs) + + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +class PhiForCausalLM(LlamaForCausalLM): + pass + +class PhiForSequenceClassification(LlamaForSequenceClassification): + pass + +class PhiForTokenClassification(LlamaForTokenClassification): + pass + +class PhiForQuestionAnswering(LlamaForQuestionAnswering): + pass + + From f3709070210a4a4ad696a71dc7e654591ff2d13e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 14:59:19 +0100 Subject: [PATCH 03/97] fix llama --- src/transformers/models/llama/modeling_llama.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 59162b4ec15b..ccb15fef2ef5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -17,8 +17,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -28,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -37,7 +36,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -45,11 +44,9 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) -from ...utils.generic import validate_config_kwargs from .configuration_llama import LlamaConfig @@ -85,12 +82,7 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[LlamaConfig] = None, ): super().__init__() From d3ef5397d750d5a11de5905a250a3e9fde273b8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 14:59:38 +0100 Subject: [PATCH 04/97] update --- src/transformers/modeling_utils.py | 4 -- src/transformers/models/olmo/modular_olmo.py | 24 ++++---- src/transformers/models/phi/modular_phi.py | 58 +++++++++----------- 3 files changed, 41 insertions(+), 45 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e5bb3fb1c8a4..c3b784b82f6e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -174,10 +174,6 @@ def is_local_dist_rank_0(): if is_peft_available(): from .utils import find_adapter_config_file - -SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") - - TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, "normal_": nn.init.normal_, diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index fc548766eb92..b6bf326d014c 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -1,23 +1,27 @@ +from typing import Callable, Optional, Tuple + import torch +import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint -from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, LlamaForMultipleChoice, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, - LlamaDecoderLayer, - LlamaRMSNorm, LlamaMLP, - LlamaAttention,apply_rotary_pos_emb, eager_attention_forward + LlamaRMSNorm, + apply_rotary_pos_emb, + eager_attention_forward, ) from .configuration_olmo import OlmoConfig -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Tuple, Callable + + class OlmoLayerNorm(nn.Module): """LayerNorm but with no learnable weight or bias.""" @@ -39,8 +43,8 @@ class OlmoRMSNorm(LlamaRMSNorm): class OlmoMLP(LlamaMLP): pass + class OlmoAttention(LlamaAttention): - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 189c69397602..a88c0fbfd453 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,37 +1,28 @@ -from .configuration_phi import PhiConfig -from ..llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb, LlamaMLP, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaForQuestionAnswering -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS -import torch.nn as nn - import math -import torch -from typing import Optional, Tuple, Callable -from transformers.cache_utils import Cache - +from typing import Callable, Optional, Tuple -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, layer_head_mask=None, **_kwargs): - key_states = repeat_kv(key_states, attention_class.num_key_value_groups) - value_states = repeat_kv(value_states, attention_class.num_key_value_groups) +import torch +import torch.nn as nn - # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow - attn_weights = torch.matmul( - query.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) - ) / math.sqrt(attention_class.head_dim) +from transformers.cache_utils import Cache - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ..llama.modeling_llama import ( + LlamaAttention, + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaMLP, + apply_rotary_pos_emb, + eager_attention_forward # copied from Llama +) +from .configuration_phi import PhiConfig - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=attention_class.attention_dropout, training=attention_class.training) - attn_output = torch.matmul(attn_weights, value_states) - return attn_output, attn_weights class PhiAttention(LlamaAttention): - def __init__(self, config: PhiConfig, layer_idx: int): super().__init__(config, layer_idx) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) @@ -44,8 +35,9 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.k_layernorm = nn.LayerNorm( config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - - def forward(self, + + def forward( + self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, @@ -59,7 +51,6 @@ def forward(self, key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) @@ -102,8 +93,11 @@ def forward(self, attn_output = self.o_proj(attn_output) return attn_output, attn_weights + class PhiMLP(LlamaMLP): pass + + class PhiDecoderLayer(nn.Module): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() @@ -152,17 +146,19 @@ def forward( outputs += (present_key_value,) return outputs - + + class PhiForCausalLM(LlamaForCausalLM): pass + class PhiForSequenceClassification(LlamaForSequenceClassification): pass + class PhiForTokenClassification(LlamaForTokenClassification): pass + class PhiForQuestionAnswering(LlamaForQuestionAnswering): pass - - From 45eac5821d4dab198d09ee064d4e8970a8aac792 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:08:52 +0100 Subject: [PATCH 05/97] modular gemmas --- .../models/gemma/modular_gemma.py | 457 +----------------- .../models/gemma2/modeling_gemma2.py | 409 +--------------- .../models/gemma2/modular_gemma2.py | 185 +------ 3 files changed, 23 insertions(+), 1028 deletions(-) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 778ef7e19b65..1c048d04f415 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -354,454 +354,11 @@ def extra_repr(self): ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) - -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class GemmaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = GemmaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GemmaSdpaAttention(GemmaAttention): - """ - Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GemmaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention): - """ - Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GemmaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -GEMMA_ATTENTION_CLASSES = { - "eager": GemmaAttention, - "flash_attention_2": GemmaFlashAttention2, - "sdpa": GemmaSdpaAttention, -} - - class GemmaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__(config) - self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - self.mlp = GemmaMLP(config) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + super().__init__(config, layer_idx) + self.input_layernorm = GemmaRMSNorm(config.hidden_size) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size) class GemmaPreTrainedModel(LlamaPreTrainedModel): @@ -809,14 +366,6 @@ class GemmaPreTrainedModel(LlamaPreTrainedModel): class GemmaModel(LlamaModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet! - self.post_init() def forward( self, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 288913697f26..765024963a73 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -43,6 +43,7 @@ is_torch_greater_or_equal, logging, replace_return_docstrings, + ) from .configuration_gemma2 import Gemma2Config @@ -56,127 +57,8 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "google/gemma2-7b" -_CONFIG_FOR_DOC = "Gemma2Config" - -class Gemma2RMSNorm(nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.zeros(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()) - # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - output = output * (1.0 + self.weight.float()) - return output.type_as(x) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.eps}" - - -class Gemma2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class Gemma2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: +def gemma_eager_attention_forward(config, query, key, value, mask, **_kwargs): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -198,15 +80,7 @@ def eager_attention_forward( return attn_output, attn_weights -def flash_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: +def gemma_flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): if mask is not None: seq_len = mask.shape[1] query = query[:, :, :seq_len] @@ -243,15 +117,7 @@ def flash_attention_forward( return attn_output, None -def flex_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - output_attentions: bool = False, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +def gemma_flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) @@ -277,14 +143,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): return attn_output, attn_weights -def sdpa_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: +def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key = repeat_kv(key, config.num_key_value_groups) value = repeat_kv(value, config.num_key_value_groups) @@ -316,132 +175,11 @@ def sdpa_attention_forward( return attn_output, None -GEMMA2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, -} - - -class Gemma2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attn_logit_softcapping = config.attn_logit_softcapping - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = Gemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" - else: - attention_type = self.config._attn_implementation - - attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class Gemma2SdpaAttention(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GemmaDecoderLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size - self.config = config - self.is_sliding = not bool(layer_idx % 2) - self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma2MLP(config) - self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -504,141 +242,6 @@ def forward( return outputs - -GEMMA2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Gemma2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", - GEMMA2_START_DOCSTRING, -) -class Gemma2PreTrainedModel(PreTrainedModel): - config_class = Gemma2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Gemma2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = False - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - - -GEMMA2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - @add_start_docstrings( "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", GEMMA2_START_DOCSTRING, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5e04fe1b63a3..4ad151a398d1 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -40,7 +40,7 @@ GemmaPreTrainedModel, GemmaRMSNorm, GemmaRotaryEmbedding, - apply_rotary_pos_emb, + GemmaAttention, repeat_kv, ) @@ -213,14 +213,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): pass -def eager_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, torch.Tensor]: +def gemma_eager_attention_forward(config, query, key, value, mask, **_kwargs): key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -242,15 +235,7 @@ def eager_attention_forward( return attn_output, attn_weights -def flash_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - target_dtype: torch.dtype = torch.float16, - **_kwargs, -) -> Tuple[torch.Tensor, None]: +def gemma_flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): if mask is not None: seq_len = mask.shape[1] query = query[:, :, :seq_len] @@ -287,15 +272,7 @@ def flash_attention_forward( return attn_output, None -def flex_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - output_attentions: bool = False, - **_kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: +def gemma_flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): def tanh_softcap(score, b, h, q_idx, kv_idx): soft_cap = config.attn_logit_softcapping score = soft_cap * torch.tanh(score / soft_cap) @@ -321,14 +298,7 @@ def tanh_softcap(score, b, h, q_idx, kv_idx): return attn_output, attn_weights -def sdpa_attention_forward( - config: Gemma2Config, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: Optional[torch.Tensor], - **_kwargs, -) -> Tuple[torch.Tensor, None]: +def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs): key = repeat_kv(key, config.num_key_value_groups) value = repeat_kv(value, config.num_key_value_groups) @@ -360,132 +330,20 @@ def sdpa_attention_forward( return attn_output, None -GEMMA2_ATTENTION_FUNCTION = { - "flash_attention_2": flash_attention_forward, - "flex_attention": flex_attention_forward, - "eager": eager_attention_forward, - "sdpa": sdpa_attention_forward, +ALL_ATTENTION_FUNCTION = { + "gemma_flash_attention_2": gemma_flash_attention_forward, + "gemma_flex_attention": gemma_flex_attention_forward, + "gemma_eager": gemma_eager_attention_forward, + "gemma_sdpa": gemma_sdpa_attention_forward, } +class Gemma2Attention(GemmaAttention): + pass -class Gemma2Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = config.query_pre_attn_scalar**-0.5 - self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None - self.attn_logit_softcapping = config.attn_logit_softcapping - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = Gemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = { - "sin": sin, - "cos": cos, - "sliding_window": self.sliding_window, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if output_attentions and self.config._attn_implementation in ["sdpa", "flash_attention_2"]: - logger.warning_once("Setting `attention_type` to `flex_attention` because `output_attentions=True`") - attention_type = "flex_attention" - else: - attention_type = self.config._attn_implementation - - attn_output, attn_weights = GEMMA2_ATTENTION_FUNCTION[attention_type]( - self, query_states, key_states, value_states, attention_mask, output_attentions=output_attentions - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Gemma2FlashAttention2(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "flash_attention_2" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class Gemma2SdpaAttention(Gemma2Attention): - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.config._attn_implementation = "sdpa" - logger.warning_once( - "The `Gemma2FlashAttention2` class is deprecated in favor of simply modifying the `config._attn_implementation`" - "attribute of the `GemmaAttention` class! It will be removed in v4.48" - ) - - -class Gemma2DecoderLayer(nn.Module): +class Gemma2DecoderLayer(GemmaDecoderLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size - self.config = config self.is_sliding = not bool(layer_idx % 2) - self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma2MLP(config) - self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -550,22 +408,7 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): - _supports_quantized_cache = False - - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False): - """ - Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models. - SDPA reduces the model performance on Gemma2 because of the logits softcapping. - """ - config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only) - - # if using the default path -> swap sdpa by eager - if not hard_check_only and config._attn_implementation == "sdpa": - config._attn_implementation = "eager" - - return config - + pass class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config): From e52af4944c7b73290eb1b179adcdbbe1010cd098 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:13:01 +0100 Subject: [PATCH 06/97] modular nits --- .../models/gemma/modular_gemma.py | 9 +- .../models/gemma2/modular_gemma2.py | 2 +- .../models/olmo2/modular_olmo2.py | 228 ++---------------- src/transformers/models/phi/modular_phi.py | 5 +- 4 files changed, 22 insertions(+), 222 deletions(-) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 1c048d04f415..f00e040f6272 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import sentencepiece as spm @@ -21,24 +20,18 @@ import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig -from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging from ..llama.modeling_llama import ( LlamaDecoderLayer, - LlamaFlashAttention2, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, - LlamaPreTrainedModel, - apply_rotary_pos_emb, - repeat_kv, ) from ..llama.tokenization_llama import LlamaTokenizer diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 4ad151a398d1..f8c5ff9adeba 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -33,6 +33,7 @@ logging, ) from ..gemma.modeling_gemma import ( + GemmaAttention, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, @@ -40,7 +41,6 @@ GemmaPreTrainedModel, GemmaRMSNorm, GemmaRotaryEmbedding, - GemmaAttention, repeat_kv, ) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 393d17c59c1a..6cee9d8c35a5 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -1,29 +1,26 @@ -import math -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch from torch import nn from ...cache_utils import Cache +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging -from ..llama.modeling_llama import LlamaRMSNorm +from ...utils import is_flash_attn_2_available, logging +from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( OlmoAttention, OlmoDecoderLayer, - OlmoFlashAttention2, OlmoForCausalLM, OlmoModel, OlmoPreTrainedModel, - OlmoSdpaAttention, apply_rotary_pos_emb, - repeat_kv, ) if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -184,96 +181,17 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states)) key_states = self.k_norm(self.k_proj(hidden_states)) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - return attn_output, attn_weights, past_key_value - - -class Olmo2FlashAttention2(OlmoFlashAttention2, Olmo2Attention): - """ - OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - Olmo2Attention.__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -283,129 +201,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2SdpaAttention(OlmoSdpaAttention, Olmo2Attention): - # Adapted from Olmo2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attn_weights # The OLMo2 layers are identical to those of the OLMo model except: diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index a88c0fbfd453..c4036bee9452 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,4 +1,3 @@ -import math from typing import Callable, Optional, Tuple import torch @@ -15,13 +14,11 @@ LlamaForTokenClassification, LlamaMLP, apply_rotary_pos_emb, - eager_attention_forward # copied from Llama + eager_attention_forward, # copied from Llama ) from .configuration_phi import PhiConfig - - class PhiAttention(LlamaAttention): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__(config, layer_idx) From 5ed37aee7f0b112a9f8a4ad5cf922f524d92d2c6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:21:58 +0100 Subject: [PATCH 07/97] modular updates --- .../modular-transformers/modeling_dummy.py | 467 ++------- .../modeling_multimodal1.py | 469 ++------- .../modeling_new_task_model.py | 34 +- .../modular-transformers/modeling_super.py | 398 ++------ src/transformers/configuration_utils.py | 3 +- src/transformers/models/aria/modeling_aria.py | 459 ++------- .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/glm/modeling_glm.py | 113 +-- .../models/mistral/modeling_mistral.py | 796 ++++----------- .../models/mistral/modular_mistral.py | 3 - src/transformers/models/phi/modeling_phi.py | 954 ++++++------------ 11 files changed, 876 insertions(+), 2822 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 6172c9acfd21..8fc0a16b89d2 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_dummy.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,17 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_dummy import DummyConfig @@ -53,40 +47,18 @@ def extra_repr(self): class DummyRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[DummyConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DummyRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +171,61 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class DummyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DummyConfig, layer_idx: Optional[int] = None): + def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DummyFlashAttention2(DummyAttention): - """ - Dummy flash attention module. This module inherits from `DummyAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,167 +235,29 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DummyRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DummySdpaAttention(DummyAttention): - """ - Dummy attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `DummyAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from DummyAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "DummyModel is using DummySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -DUMMY_ATTENTION_CLASSES = { - "eager": DummyAttention, - "flash_attention_2": DummyFlashAttention2, - "sdpa": DummySdpaAttention, -} + return attn_output, attn_weights -class DummyDecoderLayer(nn.Module): +class DummyDecoderLayer(GradientCheckpointLayer): def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = DUMMY_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = DummyAttention(config=config, layer_idx=layer_idx) self.mlp = DummyMLP(config) self.input_layernorm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -521,31 +272,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -725,10 +454,6 @@ def __init__(self, config: DummyConfig): self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DummyRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") - # Initialize weights and apply final processing self.post_init() @@ -772,31 +497,19 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -805,42 +518,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -850,18 +546,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 562e7dcab2b9..e4ee9eeb07df 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_multimodal1.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,17 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_multimodal1 import Multimodal1TextConfig @@ -53,40 +47,18 @@ def extra_repr(self): class Multimodal1TextRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[Multimodal1TextConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Multimodal1TextRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +171,61 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class Multimodal1TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Multimodal1TextConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Multimodal1TextFlashAttention2(Multimodal1TextAttention): - """ - Multimodal1Text flash attention module. This module inherits from `Multimodal1TextAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,169 +235,29 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Multimodal1TextRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Multimodal1TextSdpaAttention(Multimodal1TextAttention): - """ - Multimodal1Text attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Multimodal1TextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - # Adapted from Multimodal1TextAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Multimodal1TextModel is using Multimodal1TextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -MULTIMODAL1_TEXT_ATTENTION_CLASSES = { - "eager": Multimodal1TextAttention, - "flash_attention_2": Multimodal1TextFlashAttention2, - "sdpa": Multimodal1TextSdpaAttention, -} - - -class Multimodal1TextDecoderLayer(nn.Module): +class Multimodal1TextDecoderLayer(GradientCheckpointLayer): def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MULTIMODAL1_TEXT_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) + self.self_attn = Multimodal1TextAttention(config=config, layer_idx=layer_idx) self.mlp = Multimodal1TextMLP(config) self.input_layernorm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -523,31 +272,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -727,10 +454,6 @@ def __init__(self, config: Multimodal1TextConfig): self.norm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Multimodal1TextRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") - # Initialize weights and apply final processing self.post_init() @@ -774,31 +497,19 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -807,42 +518,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -852,18 +546,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index d303d328e887..477d084b1d93 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -10,7 +10,7 @@ import torch from torch import nn -from ...cache_utils import Cache, StaticCache +from ...cache_utils import Cache, HybridCache, StaticCache from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -253,7 +253,14 @@ def tie_weights(self): return self.language_model.tie_weights() def _update_causal_mask( - self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_ids=None, + inputs_embeds=None, + is_training: bool = False, ): if self.config.text_config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: @@ -261,11 +268,13 @@ def _update_causal_mask( return None using_static_cache = isinstance(past_key_values, StaticCache) - dtype = inputs_embeds.dtype - min_dtype = torch.finfo(dtype).min - sequence_length = inputs_embeds.shape[1] + min_dtype = torch.finfo(self.dtype).min + inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] + sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + target_length = past_key_values.get_max_cache_shape() else: target_length = ( attention_mask.shape[-1] @@ -278,7 +287,7 @@ def _update_causal_mask( return attention_mask causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device ) # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below if sequence_length != 1: @@ -288,7 +297,7 @@ def _update_causal_mask( causal_mask[:, :sequence_length] = 0.0 causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit mask_length = attention_mask.shape[-1] @@ -317,7 +326,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor): image_outputs = self.vision_tower(pixel_values) selected_image_feature = image_outputs.last_hidden_state image_features = self.multi_modal_projector(selected_image_feature) - image_features = image_features / (self.config.hidden_size**0.5) + image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING) @@ -414,6 +423,7 @@ def prepare_inputs_for_generation( token_type_ids=None, use_cache=True, num_logits_to_keep=None, + labels=None, **kwargs, ): # Overwritten -- custom `position_ids` and `pixel_values` handling @@ -433,12 +443,16 @@ def prepare_inputs_for_generation( # position_ids in NewTaskModel are 1-indexed if model_inputs.get("position_ids") is not None: model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values - + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + causal_mask = self._update_causal_mask( + attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training + ) + model_inputs["attention_mask"] = causal_mask return model_inputs def resize_token_embeddings( diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 79e5ab15a5ed..0bec83258924 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_super.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,23 +12,15 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from .configuration_super import SuperConfig -logger = logging.get_logger(__name__) - - class SuperRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -53,40 +44,18 @@ def extra_repr(self): class SuperRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[SuperConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`SuperRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -199,144 +168,61 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class SuperAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: SuperConfig, layer_idx: Optional[int] = None): + def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - return attn_output, attn_weights, past_key_value - - -class SuperFlashAttention2(SuperAttention): - """ - Super flash attention module. This module inherits from `SuperAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -346,167 +232,29 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (SuperRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class SuperSdpaAttention(SuperAttention): - """ - Super attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `SuperAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from SuperAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "SuperModel is using SuperSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - return attn_output, None, past_key_value - - -SUPER_ATTENTION_CLASSES = { - "eager": SuperAttention, - "flash_attention_2": SuperFlashAttention2, - "sdpa": SuperSdpaAttention, -} - -class SuperDecoderLayer(nn.Module): +class SuperDecoderLayer(GradientCheckpointLayer): def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = SUPER_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = SuperAttention(config=config, layer_idx=layer_idx) self.mlp = SuperMLP(config) self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -521,31 +269,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -725,10 +451,6 @@ def __init__(self, config: SuperConfig): self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SuperRotaryEmbedding(config=config) - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") - # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index a04b7bd6aa1b..638bc88b1425 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -37,11 +37,10 @@ download_url, extract_commit_hash, is_remote_url, - is_timm_config_dict, is_torch_available, logging, ) - +from .utils.generic import is_timm_config_dict logger = logging.get_logger(__name__) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c3e3e424a4ba..d1546ae220ec 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -18,24 +18,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -478,144 +476,61 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class AriaTextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: AriaTextConfig, layer_idx: Optional[int] = None): + def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class AriaTextFlashAttention2(AriaTextAttention): - """ - AriaText flash attention module. This module inherits from `AriaTextAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -625,162 +540,24 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (AriaTextRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value - - -class AriaTextSdpaAttention(AriaTextAttention): - """ - AriaText attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `AriaTextAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from AriaTextAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "AriaTextModel is using AriaTextSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -ARIA_TEXT_ATTENTION_CLASSES = { - "eager": AriaTextAttention, - "flash_attention_2": AriaTextFlashAttention2, - "sdpa": AriaTextSdpaAttention, -} - - -class AriaTextDecoderLayer(nn.Module): +class AriaTextDecoderLayer(GradientCheckpointLayer): """ Aria Text Decoder Layer. @@ -797,7 +574,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ARIA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = AriaTextAttention(config=config, layer_idx=layer_idx) self.mlp = AriaTextMoELayer(config) self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -811,31 +588,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -953,40 +708,18 @@ def _init_weights(self, module): class AriaTextRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[AriaTextConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`AriaTextRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -1136,8 +869,6 @@ def __init__(self, config: AriaTextConfig): self.norm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = AriaTextRotaryEmbedding(config=config) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -1182,31 +913,19 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -1215,42 +934,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1260,18 +962,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index db25591eaa35..b3ec2ef157c9 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -30,12 +30,12 @@ CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, - is_timm_config_dict, is_timm_local_checkpoint, is_torchvision_available, is_vision_available, logging, ) +from .utils.generic import is_timm_config_dict from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index b4a292d69de9..c01e2f14838f 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -36,7 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, @@ -476,19 +476,12 @@ def forward( return attn_output, None, past_key_value -GLM_ATTENTION_CLASSES = { - "eager": GlmAttention, - "flash_attention_2": GlmFlashAttention2, - "sdpa": GlmSdpaAttention, -} - - -class GlmDecoderLayer(nn.Module): +class GlmDecoderLayer(GradientCheckpointLayer): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GLM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = GlmAttention(config=config, layer_idx=layer_idx) self.mlp = GlmMLP(config) self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -503,31 +496,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -711,8 +682,6 @@ def __init__(self, config: GlmConfig): base=config.rope_theta, ) self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") # Initialize weights and apply final processing self.post_init() @@ -757,31 +726,19 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -790,42 +747,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -835,18 +775,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6ed8178ed982..c983ea15e433 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1,66 +1,37 @@ -# coding=utf-8 -# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Mistral model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral/modular_mistral.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_mistral import MistralConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" +_CHECKPOINT_FOR_DOC = "meta-mistral/Mistral-2-7b-hf" _CONFIG_FOR_DOC = "MistralConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral class MistralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -82,24 +53,55 @@ def extra_repr(self): class MistralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[MistralConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -107,10 +109,30 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -# Copied from transformers.models.llama.modeling_llama.rotate_half +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -118,7 +140,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -146,21 +167,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class MistralMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -173,257 +179,63 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class MistralAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = MistralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MistralFlashAttention2(MistralAttention): - """ - Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral -# TODO(joao): add me back asap :) -class MistralSdpaAttention(MistralAttention): - """ - Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MistralAttention.forward def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -431,56 +243,29 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights -MISTRAL_ATTENTION_CLASSES = { - "eager": MistralAttention, - "flash_attention_2": MistralFlashAttention2, - "sdpa": MistralSdpaAttention, -} - - -# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Mistral, LLAMA->MISTRAL -# TODO(joao): add me back asap :) -class MistralDecoderLayer(nn.Module): +class MistralDecoderLayer(GradientCheckpointLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) self.mlp = MistralMLP(config) self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -495,27 +280,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -529,6 +296,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -576,10 +344,11 @@ class MistralPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MistralDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True _supports_static_cache = True def _init_weights(self, module): @@ -663,7 +432,7 @@ def _init_weights(self, module): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`, + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. """ @@ -690,10 +459,9 @@ def __init__(self, config: MistralConfig): self.layers = nn.ModuleList( [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MistralRotaryEmbedding(config=config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -716,93 +484,67 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -812,18 +554,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -977,15 +714,26 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - +@add_start_docstrings( + """ + The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MISTRAL_START_DOCSTRING, +) +class MistralForTokenClassification(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) + self.num_labels = config.num_labels self.model = MistralModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @@ -996,75 +744,35 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( - input_ids=input_ids, + input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, @@ -1073,36 +781,22 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device - shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits, labels, self.config) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( + return TokenClassifierOutput( loss=loss, logits=logits, - past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) @@ -1123,7 +817,6 @@ def forward( """, MISTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL class MistralForSequenceClassification(MistralPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1213,95 +906,6 @@ def forward( ) -@add_start_docstrings( - """ - The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - MISTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mistral, LLAMA->MISTRAL -class MistralForTokenClassification(MistralPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = MistralModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - @add_start_docstrings( """ The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like @@ -1309,24 +913,22 @@ def forward( """, MISTRAL_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForQuestionAnswering with Llama->Mistral,LLAMA->MISTRAL,transformer->model class MistralForQuestionAnswering(MistralPreTrainedModel): - base_model_prefix = "model" + base_model_prefix = "transformer" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mistral def __init__(self, config): super().__init__(config) - self.model = MistralModel(config) + self.transformer = MistralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.transformer.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value + self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( @@ -1355,7 +957,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index f60badcce8d3..33d6154a5a4c 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -177,6 +177,3 @@ class MistralForSequenceClassification(LlamaForSequenceClassification): class MistralForQuestionAnswering(LlamaForQuestionAnswering): pass - -class MistralForMultipleChoice(LlamaForMultipleChoice): - pass diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8e60798e857f..73653b404c6a 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1,153 +1,48 @@ -# coding=utf-8 -# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""PyTorch Phi model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/phi/modular_phi.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_phi.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from packaging import version -from torch import nn -from torch.nn import CrossEntropyLoss +import torch.nn as nn + +from transformers.cache_utils import Cache from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, + QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - get_torch_version, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_phi import PhiConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "microsoft/phi-1" +_CHECKPOINT_FOR_DOC = "meta-phi/Phi-2-7b-hf" _CONFIG_FOR_DOC = "PhiConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi -class PhiRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[PhiConfig] = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`PhiRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -155,7 +50,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -183,23 +77,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi -class PhiMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) - return hidden_states - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -212,41 +89,48 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class PhiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): + def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) - + self.dense = self.o_proj self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( @@ -256,284 +140,26 @@ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True ) - self.rotary_emb = PhiRotaryEmbedding(config=self.config) - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.qk_layernorm: - query_states = self.q_layernorm(query_states) - key_states = self.k_layernorm(key_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_ndims], - query_states[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_ndims], - key_states[..., self.rotary_ndims :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow - attn_weights = torch.matmul( - query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) - ) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class PhiFlashAttention2(PhiAttention): - """ - Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # PhiFlashAttention2 attention does not support output_attentions - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.qk_layernorm: - query_states = self.q_layernorm(query_states) - key_states = self.k_layernorm(key_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_ndims], - query_states[..., self.rotary_ndims :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_ndims], - key_states[..., self.rotary_ndims :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_dropout = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. - - if query_states.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=attn_dropout, - softmax_scale=None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.dense(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class PhiSdpaAttention(PhiAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - - """ - SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `PhiAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from PhiAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "PhiModel is using PhiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not " - "support `output_attentions=True`. Falling back to the manual attention implementation, but specifying " - "the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can " - 'be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) if self.qk_layernorm: query_states = self.q_layernorm(query_states) key_states = self.k_layernorm(key_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - # Partial rotary embedding query_rot, query_pass = ( query_states[..., : self.rotary_ndims], @@ -551,61 +177,47 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - cache_kwargs = { - "sin": sin, - "cos": cos, - "partial_rotation_size": self.rotary_ndims, - "cache_position": cache_position, - } + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.dense(attn_output) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - return attn_output, None, past_key_value +class PhiMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] -PHI_ATTENTION_CLASSES = { - "eager": PhiAttention, - "flash_attention_2": PhiFlashAttention2, - "sdpa": PhiSdpaAttention, -} + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class PhiDecoderLayer(nn.Module): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__() - self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.self_attn = PhiAttention(config, layer_idx=layer_idx) self.mlp = PhiMLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.resid_dropout = nn.Dropout(config.resid_pdrop) @@ -622,32 +234,6 @@ def forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): - input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range - `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -678,6 +264,91 @@ def forward( return outputs +class PhiRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + PhiRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class PhiRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[PhiConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + PHI_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -704,12 +375,12 @@ class PhiPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PhiDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True - _supports_static_cache = True _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range @@ -816,17 +487,12 @@ def __init__(self, config: PhiConfig): self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.embed_dropout = nn.Dropout(config.embd_pdrop) self.layers = nn.ModuleList( [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.norm = PhiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = PhiRotaryEmbedding(config=config) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" - - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -842,54 +508,40 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -897,7 +549,6 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - inputs_embeds = self.embed_dropout(inputs_embeds) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -906,64 +557,42 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - output_attentions, - use_cache, - past_key_values, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -1030,7 +659,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1087,40 +715,37 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True def __init__(self, config): super().__init__(config) self.model = PhiModel(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings def get_input_embeddings(self): return self.model.embed_tokens - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings def set_input_embeddings(self, value): self.model.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings def get_output_embeddings(self): return self.lm_head - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder def set_decoder(self, decoder): self.model = decoder - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder def get_decoder(self): return self.model @@ -1131,7 +756,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1140,7 +765,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1161,18 +786,17 @@ def forward( ```python >>> from transformers import AutoTokenizer, PhiForCausalLM - >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1") - >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") + >>> model = PhiForCausalLM.from_pretrained("meta-phi/Phi-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-phi/Phi-2-7b-hf") - >>> prompt = "This is an example script ." + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str' + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1191,6 +815,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1199,7 +824,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1216,7 +841,7 @@ def forward( @add_start_docstrings( """ - The PhiModel with a sequence classification head on top (linear layer). + The Phi Model transformer with a sequence classification head on top (linear layer). [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. @@ -1229,7 +854,6 @@ def forward( """, PHI_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs class PhiForSequenceClassification(PhiPreTrainedModel): def __init__(self, config): super().__init__(config) @@ -1268,7 +892,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1279,7 +903,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - hidden_states = model_outputs[0] + hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: @@ -1307,44 +931,48 @@ def forward( loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) if not return_dict: - output = (pooled_logits,) + model_outputs[1:] + output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, - past_key_values=model_outputs.past_key_values, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, ) @add_start_docstrings( """ - PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for - Named-Entity-Recognition (NER) tasks. + The Phi Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. """, PHI_START_DOCSTRING, ) -# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs class PhiForTokenClassification(PhiPreTrainedModel): - def __init__(self, config: PhiConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.model = PhiModel(config) - if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + if getattr(config, "classifier_dropout", None) is not None: classifier_dropout = config.classifier_dropout - elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + elif getattr(config, "hidden_dropout", None) is not None: classifier_dropout = config.hidden_dropout else: classifier_dropout = 0.1 self.dropout = nn.Dropout(classifier_dropout) - self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) @add_code_sample_docstrings( checkpoint=_CHECKPOINT_FOR_DOC, @@ -1354,16 +982,16 @@ def __init__(self, config: PhiConfig): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments, - ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., @@ -1372,38 +1000,118 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - model_outputs = self.model( + outputs = self.model( input_ids, - past_key_values=past_key_values, attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) - - hidden_states = model_outputs[0] - hidden_states = self.dropout(hidden_states) - logits = self.classifier(hidden_states) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) loss = None if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - batch_size, seq_length = labels.shape - loss_fct = CrossEntropyLoss() - loss = loss_fct( - logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) - ) + loss = self.loss_function(logits, labels, self.config) if not return_dict: - output = (logits,) + model_outputs[2:] + output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return TokenClassifierOutput( loss=loss, logits=logits, - hidden_states=model_outputs.hidden_states, - attentions=model_outputs.attentions, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Phi Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + PHI_START_DOCSTRING, +) +class PhiForQuestionAnswering(PhiPreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = PhiModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) From 38cafc1aba35c48cfcc45d3f489fb44ff6508df3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:23:01 +0100 Subject: [PATCH 08/97] nits --- .../models/gemma2/modeling_gemma2.py | 450 +++++--- src/transformers/models/olmo/modeling_olmo.py | 963 ++++++++---------- src/transformers/models/olmo/modular_olmo.py | 2 - 3 files changed, 732 insertions(+), 683 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 765024963a73..4a018e5e0a61 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,6 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List, Optional, Tuple, Union import torch @@ -38,148 +39,230 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal, - is_torch_greater_or_equal, logging, replace_return_docstrings, - ) from .configuration_gemma2 import Gemma2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +logger = logging.get_logger(__name__) -if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention -logger = logging.get_logger(__name__) +class Gemma2RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) -def gemma_eager_attention_forward(config, query, key, value, mask, **_kwargs): - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling - - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping - attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def gemma_flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, - ) + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - return attn_output, None - - -def gemma_flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - attn_weights = None - else: - attn_output, attn_weights = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs): - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +class Gemma2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class Gemma2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Gemma2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.scaling = 1 / math.sqrt(config.head_dim) + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = Gemma2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value class Gemma2DecoderLayer(GemmaDecoderLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() + self.is_sliding = not bool(layer_idx % 2) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -242,6 +325,127 @@ def forward( return outputs + +GEMMA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Gemma2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", + GEMMA2_START_DOCSTRING, +) +class Gemma2PreTrainedModel(PreTrainedModel): + config_class = Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +GEMMA2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + @add_start_docstrings( "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", GEMMA2_START_DOCSTRING, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 8b40c41e34dc..ec5a1cf2f468 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1,59 +1,35 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch OLMo model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo/modular_olmo.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import torch +import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache -from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_olmo import OlmoConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "meta-olmo/Olmo-2-7b-hf" _CONFIG_FOR_DOC = "OlmoConfig" @@ -71,70 +47,42 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) -ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm) - - -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo -# TODO(joao): add me back asap :) -class OlmoRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): +class OlmoRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + OlmoRMSNorm is equivalent to T5LayerNorm + """ super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding): - """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin +class OlmoMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] -class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding): - """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -142,7 +90,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -170,22 +117,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class OlmoMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -198,282 +129,57 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class OlmoAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo - # TODO(joao): add me back asap :) - def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None): + def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = OlmoRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = OlmoLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = OlmoDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class OlmoFlashAttention2(OlmoAttention): - """ - OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class OlmoSdpaAttention(OlmoAttention): - """ - OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `OlmoAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from OlmoAttention.forward def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "OlmoModel is using OlmoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) @@ -484,9 +190,9 @@ def forward( key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -496,62 +202,33 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -OLMO_ATTENTION_CLASSES = { - "eager": OlmoAttention, - "flash_attention_2": OlmoFlashAttention2, - "sdpa": OlmoSdpaAttention, -} + return attn_output, attn_weights -class OlmoDecoderLayer(nn.Module): +class OlmoDecoderLayer(GradientCheckpointLayer): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OLMO_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward - # TODO(joao): add me back asap :) def forward( self, hidden_states: torch.Tensor, @@ -561,27 +238,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -595,6 +254,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -616,6 +276,71 @@ def forward( return outputs +class OlmoRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[OlmoConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -637,7 +362,6 @@ def forward( "The bare Olmo Model outputting raw hidden-states without any specific head on top.", OLMO_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo class OlmoPreTrainedModel(PreTrainedModel): config_class = OlmoConfig base_model_prefix = "model" @@ -758,8 +482,8 @@ def __init__(self, config: OlmoConfig): self.layers = nn.ModuleList( [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = OlmoLayerNorm(config.hidden_size) - self.gradient_checkpointing = False + self.norm = OlmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = OlmoRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -771,8 +495,6 @@ def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - # copied from transformers.models.llama.modeling_llama.LlamaModel.forward - # TODO(joao): add me back asap :) def forward( self, input_ids: torch.LongTensor = None, @@ -785,6 +507,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -805,25 +528,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -831,45 +541,33 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -879,20 +577,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() - # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, @@ -959,7 +651,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1016,15 +707,26 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo -class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - +@add_start_docstrings( + """ + The Olmo Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + OLMO_START_DOCSTRING, +) +class OlmoForTokenClassification(OlmoPreTrainedModel): def __init__(self, config): super().__init__(config) + self.num_labels = config.num_labels self.model = OlmoModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) # Initialize weights and apply final processing self.post_init() @@ -1035,104 +737,249 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.model.embed_tokens = value - def get_output_embeddings(self): - return self.lm_head + @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) - def set_decoder(self, decoder): - self.model = decoder + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) - def get_decoder(self): - return self.model + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Olmo Model transformer with a sequence classification head on top (linear layer). + + [`OlmoForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + OLMO_START_DOCSTRING, +) +class OlmoForSequenceClassification(OlmoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OlmoModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, - input_ids: torch.LongTensor = None, + input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - Returns: +@add_start_docstrings( + """ +The Olmo Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + OLMO_START_DOCSTRING, +) +class OlmoForQuestionAnswering(OlmoPreTrainedModel): + base_model_prefix = "transformer" - Example: + def __init__(self, config): + super().__init__(config) + self.transformer = OlmoModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) - ```python - >>> from transformers import AutoTokenizer, OlmoForCausalLM + # Initialize weights and apply final processing + self.post_init() - >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-1B-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf") + def get_input_embeddings(self): + return self.transformer.embed_tokens - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' - ``` + @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, + outputs = self.transformer( + input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, ) - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( + return QuestionAnsweringModelOutput( loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, + start_logits=start_logits, + end_logits=end_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index b6bf326d014c..10478537e4df 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -117,5 +117,3 @@ class OlmoForQuestionAnswering(LlamaForQuestionAnswering): pass -class OlmoForMultipleChoice(LlamaForMultipleChoice): - pass From a862eacd7515d87008eb67079c4e989269e54a66 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:33:01 +0100 Subject: [PATCH 09/97] simplify --- .../models/gemma/modeling_gemma.py | 509 ++++----------- src/transformers/models/olmo/modeling_olmo.py | 135 +++- src/transformers/models/olmo/modular_olmo.py | 5 +- .../models/olmo2/configuration_olmo2.py | 1 + .../models/olmo2/modeling_olmo2.py | 598 +++++------------- .../models/starcoder2/modeling_starcoder2.py | 34 +- 6 files changed, 439 insertions(+), 843 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index b3253fdd5614..12b6cd81451b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,19 +28,20 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -74,85 +74,20 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -class GemmaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding): - """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - class GemmaMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def rotate_half(x): @@ -201,241 +136,63 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class GemmaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): + def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = GemmaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GemmaSdpaAttention(GemmaAttention): - """ - Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GemmaAttention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class GemmaFlashAttention2(GemmaAttention): - """ - Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -443,76 +200,33 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GemmaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -GEMMA_ATTENTION_CLASSES = { - "eager": GemmaAttention, - "flash_attention_2": GemmaFlashAttention2, - "sdpa": GemmaSdpaAttention, -} + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights -class GemmaDecoderLayer(nn.Module): +class GemmaDecoderLayer(GradientCheckpointLayer): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + self.mlp = GemmaMLP(config) - self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = GemmaRMSNorm(config.hidden_size) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size) def forward( self, @@ -523,27 +237,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -557,6 +253,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -578,6 +275,71 @@ def forward( return outputs +class GemmaRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[GemmaConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GEMMA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -720,10 +482,7 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + self.rotary_emb = GemmaRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index ec5a1cf2f468..478343d2d24b 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -12,10 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, + CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, @@ -23,7 +25,14 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_olmo import OlmoConfig @@ -707,6 +716,130 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = OlmoModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OlmoForCausalLM + + >>> model = OlmoForCausalLM.from_pretrained("meta-olmo/Olmo-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo/Olmo-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @add_start_docstrings( """ The Olmo Model transformer with a token classification head on top (a linear layer on top of the hidden-states diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 10478537e4df..8f397c512a75 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -10,7 +10,7 @@ from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, - LlamaForMultipleChoice, + LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, @@ -104,6 +104,9 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) +class OlmoForCausalLM(LlamaForCausalLM): + pass + class OlmoForTokenClassification(LlamaForTokenClassification): pass diff --git a/src/transformers/models/olmo2/configuration_olmo2.py b/src/transformers/models/olmo2/configuration_olmo2.py index 144520f87ed7..83c3263de1f5 100644 --- a/src/transformers/models/olmo2/configuration_olmo2.py +++ b/src/transformers/models/olmo2/configuration_olmo2.py @@ -5,6 +5,7 @@ # modular_olmo2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 6c35587f1f14..9931346e1c29 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -4,35 +4,31 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -from torch import nn +import torch.nn as nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_olmo2 import Olmo2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "Olmo2Config" @@ -56,66 +52,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo2 -# TODO(joao): add me back asap :) -class Olmo2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -class Olmo2LinearScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -class Olmo2DynamicNTKScalingRotaryEmbedding(Olmo2RotaryEmbedding): - """Olmo2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -162,73 +98,49 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class Olmo2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo2 - # TODO(joao): add me back asap :) def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = Olmo2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = Olmo2LinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = Olmo2DynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - def forward( self, hidden_states: torch.Tensor, @@ -240,15 +152,16 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_norm(self.q_proj(hidden_states)) key_states = self.k_norm(self.k_proj(hidden_states)) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -258,220 +171,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2FlashAttention2(Olmo2Attention): - """ - Olmo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - - OLMo2 flash attention module. This module inherits from `Olmo2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (OlmoRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Olmo2SdpaAttention(Olmo2Attention): - """ - Olmo2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Olmo2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Olmo2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Olmo2Model is using Olmo2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - bsz, q_len, _ = hidden_states.size() - query_states = self.q_norm(self.q_proj(hidden_states)) - key_states = self.k_norm(self.k_proj(hidden_states)) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, attn_weights class Olmo2MLP(nn.Module): @@ -480,35 +194,26 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj -OLMO2_ATTENTION_CLASSES = { - "eager": Olmo2Attention, - "flash_attention_2": Olmo2FlashAttention2, - "sdpa": Olmo2SdpaAttention, -} - - -class Olmo2DecoderLayer(nn.Module): +class Olmo2DecoderLayer(GradientCheckpointLayer): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OLMO2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - + self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) self.mlp = Olmo2MLP(config) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward - # TODO(joao): add me back asap :) def forward( self, hidden_states: torch.Tensor, @@ -520,25 +225,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states # Self Attention @@ -614,6 +300,71 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +class Olmo2RotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[Olmo2Config] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + OLMO2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -711,7 +462,7 @@ def __init__(self, config: Olmo2Config): [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False + self.rotary_emb = Olmo2RotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -723,8 +474,6 @@ def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) - # copied from transformers.models.llama.modeling_llama.LlamaModel.forward - # TODO(joao): add me back asap :) def forward( self, input_ids: torch.LongTensor = None, @@ -737,6 +486,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -757,25 +507,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -783,45 +520,33 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -831,18 +556,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -966,9 +686,12 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask -# TODO: re-enable check: Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO2,Llama->Olmo2 +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config: Olmo2Config): super().__init__(config) @@ -999,13 +722,12 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(OLMO2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1014,7 +736,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1035,8 +757,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, Olmo2ForCausalLM - >>> model = Olmo2ForCausalLM.from_pretrained("allenai/Olmo2-1B-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo2-1B-hf") + >>> model = Olmo2ForCausalLM.from_pretrained("meta-olmo2/Olmo2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-olmo2/Olmo2-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1044,9 +766,8 @@ def forward( >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' - ``` - """ + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1065,6 +786,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1073,7 +795,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8047e23bb05b..ed1651996793 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -66,40 +66,18 @@ class Starcoder2RotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[Starcoder2Config] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] From 5639b81a5d22e5907cee4179ac72233798c3a271 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 15:59:45 +0100 Subject: [PATCH 10/97] gpt2 --- src/transformers/models/gpt2/modeling_gpt2.py | 368 ++++-------------- 1 file changed, 70 insertions(+), 298 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 58143192c204..d3b89eeb2f7c 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -120,6 +120,46 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model +def eager_attention_forward(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + class GPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -180,45 +220,7 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - ) - - # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) @@ -272,104 +274,18 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - def _merge_heads(self, tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask - else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - if use_cache is True: - present = (key, value) - else: - present = None - - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) - else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) - - -class GPT2FlashAttention2(GPT2Attention): - """ - GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - bsz, _, _ = hidden_states.size() if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): raise ValueError( @@ -383,184 +299,43 @@ def forward( else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - present = None - if use_cache is True: - present = (key, value) - - query_length = query.shape[2] - tgt_len = key.shape[2] - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) - key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) - - attn_dropout = self.attn_dropout.p if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - if query.dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.c_proj.weight.dtype - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) - - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) - attn_output = self.c_proj(attn_weights_reshaped) - attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights_reshaped,) + query_states = query.view(hidden_shape).transpose(1, 2) + key_states = key.view(hidden_shape).transpose(1, 2) + value_states = value.view(hidden_shape).transpose(1, 2) - return outputs + cos, sin = position_embeddings + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) -class GPT2SdpaAttention(GPT2Attention): - """ - GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GPT2Attention` as the weights of the module stays untouched. The only changes are on the forward pass - to adapt to the SDPA API. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Idea adapted from transformers.models.bert.modeling_bert.BertSdpaSelfAttention.__init__ - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()`. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if output_attentions or head_mask is not None: - logger.warning_once( - "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " - "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " - 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - bsz, q_len, _ = hidden_states.size() - - # Initial attention projections - is_cross_attention = encoder_hidden_states is not None - if is_cross_attention: - if not hasattr(self, "q_attn"): - raise ValueError( - "If class is used as cross attention, the weights `q_attn` have to be defined. " - "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`." - ) - - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) - attention_mask = encoder_attention_mask + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - # Optional kv caching - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) - - present = None - if use_cache is True: - present = (key, value) - - # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA - if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.attn_dropout.p if self.training else 0.0, - is_causal=is_causal, + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, ) - # Reshape outputs - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.embed_dim) - - # Final projection + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - return attn_output, present, None - + return attn_output, attn_weights + class GPT2MLP(nn.Module): def __init__(self, intermediate_size, config): @@ -579,22 +354,19 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} - class GPT2Block(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() hidden_size = config.hidden_size inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size - attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = attention_class(config=config, layer_idx=layer_idx) + self.attn = GPT2Attention(config=config, layer_idx=layer_idx) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) if config.add_cross_attention: - self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.crossattention = GPT2Attention(config=config, is_cross_attention=True, layer_idx=layer_idx) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPT2MLP(inner_dim, config) From 452d8edcddc120b49b35317c80cecb51d81757a3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 16:08:53 +0100 Subject: [PATCH 11/97] more modualr and fixes --- src/transformers/configuration_utils.py | 1 + .../models/auto/image_processing_auto.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 7 +- .../models/mistral/modular_mistral.py | 1 - .../models/qwen2/modeling_qwen2.py | 1457 +---------------- 5 files changed, 36 insertions(+), 1432 deletions(-) diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 638bc88b1425..648877c8dce9 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -42,6 +42,7 @@ ) from .utils.generic import is_timm_config_dict + logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index b3ec2ef157c9..00b0c049c91d 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -35,7 +35,6 @@ is_vision_available, logging, ) -from .utils.generic import is_timm_config_dict from .auto_factory import _LazyAutoMapping from .configuration_auto import ( CONFIG_MAPPING_NAMES, @@ -43,6 +42,7 @@ model_type_to_module_name, replace_list_option_in_docstrings, ) +from .utils.generic import is_timm_config_dict logger = logging.get_logger(__name__) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index d3b89eeb2f7c..63e0d14b2735 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -23,7 +23,6 @@ import torch import torch.utils.checkpoint -from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -44,9 +43,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - get_torch_version, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -55,7 +52,7 @@ if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -335,7 +332,7 @@ def forward( attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) return attn_output, attn_weights - + class GPT2MLP(nn.Module): def __init__(self, intermediate_size, config): diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 33d6154a5a4c..2d36df948d3e 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -4,7 +4,6 @@ from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ..llama.modeling_llama import ( - LlamaForMultipleChoice, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f4875386253c..b956578da1f8 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -19,37 +19,15 @@ # limitations under the License. """PyTorch Qwen2 model.""" -import math -from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, - replace_return_docstrings, ) -from .configuration_qwen2 import Qwen2Config if is_flash_attn_2_available(): @@ -59,1413 +37,42 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B" -_CONFIG_FOR_DOC = "Qwen2Config" - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) +def qwen2_flash_attention_forward( + config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs +): + if attention_mask is not None: + seq_len = attention_mask.shape[1] + query = query[:, :, :seq_len] + value = value[:, :, :seq_len] + else: + seq_len = query.shape[1] - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + dropout_rate = config.attention_dropout if training else 0.0 - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: Optional[Qwen2Config] = None, + sliding_window = None + if ( + config.use_sliding_window + and getattr(config, "sliding_window", None) is not None + and layer_idx >= config.max_window_layers ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding(config=self.config) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if ( - self.config.use_sliding_window - and getattr(self.config, "sliding_window", None) is not None - and self.layer_idx >= self.config.max_window_layers - ): - sliding_window = self.config.sliding_window - else: - sliding_window = None - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." - ) - self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_window = config.sliding_window + + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + seq_len, + config=config, + dropout=dropout_rate, + layer_idx=layer_idx, + sliding_window=sliding_window, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen2RotaryEmbedding(config=config) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Qwen2 - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen2Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - _tp_plan = {"lm_head": "colwise_rep"} - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - QWEN2_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 -class Qwen2ForTokenClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - QWEN2_START_DOCSTRING, -) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2, MISTRAL->QWEN2 -class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): - base_model_prefix = "model" - - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2 - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + return attn_output, None - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output +ALL_ATTENTION_FUNCTIONS["qwen2_flash_attention2"] = qwen2_flash_attention_forward - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) +# TODO mask creation is also different for Qwen2 +# MAKE SURE THAT config._attn_implementation is automatically qwen2_flash_attention2 From 81a0b664ca3fe674573deedf76681dca7165b8fc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 16:31:25 +0100 Subject: [PATCH 12/97] granite --- .../modeling_my_new_model2.py | 488 ++++-------- src/transformers/modeling_utils.py | 2 + .../models/gemma2/modeling_gemma2.py | 195 ++--- .../models/granite/modeling_granite.py | 695 +++++------------- .../models/granite/modular_granite.py | 138 ++++ 5 files changed, 574 insertions(+), 944 deletions(-) create mode 100644 src/transformers/models/granite/modular_granite.py diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 189e090094c7..fa89cc368380 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -4,8 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_my_new_model2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -13,15 +12,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast -from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_my_new_model2 import MyNewModel2Config @@ -48,56 +44,20 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -class MyNewModel2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) - - @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class MyNewModel2MLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - if config.hidden_activation is None: - logger.warning_once( - "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n" - "MyNewModel2's activation function will be set to `gelu_pytorch_tanh`. Please, use\n" - "`config.hidden_activation` if you want to override this behaviour.\n" - "See https://github.com/huggingface/transformers/pull/29402 for more details." - ) - config.hidden_activation = "gelu_pytorch_tanh" - hidden_activation = config.hidden_activation - self.act_fn = ACT2FN[hidden_activation] + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def rotate_half(x): @@ -146,241 +106,63 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class MyNewModel2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MyNewModel2Config, layer_idx: Optional[int] = None): + def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = MyNewModel2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class MyNewModel2SdpaAttention(MyNewModel2Attention): - """ - MyNewModel2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MyNewModel2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MyNewModel2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MyNewModel2Model is using MyNewModel2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -class MyNewModel2FlashAttention2(MyNewModel2Attention): - """ - MyNewModel2 flash attention module. This module inherits from `MyNewModel2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -388,78 +170,33 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (MyNewModel2RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - -MY_NEW_MODEL2_ATTENTION_CLASSES = { - "eager": MyNewModel2Attention, - "flash_attention_2": MyNewModel2FlashAttention2, - "sdpa": MyNewModel2SdpaAttention, -} + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights -class MyNewModel2DecoderLayer(nn.Module): +class MyNewModel2DecoderLayer(GradientCheckpointLayer): def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MY_NEW_MODEL2_ATTENTION_CLASSES[config._attn_implementation]( - config=config, layer_idx=layer_idx - ) + + self.self_attn = MyNewModel2Attention(config=config, layer_idx=layer_idx) + self.mlp = MyNewModel2MLP(config) - self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size) + self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size) def forward( self, @@ -470,27 +207,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -504,6 +223,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = residual + hidden_states @@ -525,6 +245,71 @@ def forward( return outputs +class MyNewModel2RotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[MyNewModel2Config] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + MY_NEW_MODEL2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -667,10 +452,7 @@ def __init__(self, config: MyNewModel2Config): [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + self.rotary_emb = MyNewModel2RotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c3b784b82f6e..410774c5eb91 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1486,6 +1486,8 @@ def _autoset_attn_implementation( message += ( ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' ) + if cls.model_type + "_"+ config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + config._attn_implementation_internal = cls.model_type+ "_" + config._attn_implementation if config._attn_implementation in ALL_ATTENTION_FUNCTIONS: pass else: diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 4a018e5e0a61..b8297d8e6ce8 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -19,8 +19,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -28,13 +27,16 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, @@ -86,23 +88,55 @@ def forward(self, x): class Gemma2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[Gemma2Config] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -110,6 +144,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -159,68 +198,63 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class Gemma2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): + def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.scaling = 1 / math.sqrt(config.head_dim) - - if self.hidden_size % self.num_heads != 0: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = Gemma2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -228,35 +262,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights class Gemma2DecoderLayer(GemmaDecoderLayer): @@ -468,10 +488,7 @@ def __init__(self, config: Gemma2Config): [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") + self.rotary_emb = Gemma2RotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 8cd24265d9ed..20bb591bbda4 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/granite/modular_granite.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_granite.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. # @@ -13,29 +19,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -43,96 +43,9 @@ logger = logging.get_logger(__name__) - _CONFIG_FOR_DOC = "GraniteConfig" -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Granite -class GraniteRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GraniteRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -ALL_LAYERNORM_LAYERS.append(GraniteRMSNorm) - - -class GraniteRotaryEmbedding(nn.Module): - def __init__(self, config: GraniteConfig): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half with Llama->Granite def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -140,7 +53,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb with Llama->Granite def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -168,24 +80,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class GraniteMLP(nn.Module): - # Copied from transformers.models.llama.modeling_llama.LlamaMLP.__init__ with Llama->Granite - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - # Copied from transformers.models.gemma.modeling_gemma.GemmaMLP.forward with Gemma->Granite - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv with Llama->Granite def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -198,6 +92,23 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class GraniteAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -205,247 +116,37 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.is_causal = True - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.attention_multiplier - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GraniteFlashAttention2(GraniteAttention): - """ - Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GraniteRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - softmax_scale=self.scaling, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GraniteSdpaAttention(GraniteAttention): - """ - Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GraniteAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GraniteAttention.forward def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GraniteModel is using GraniteSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -455,47 +156,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -GRANITE_ATTENTION_CLASSES = { - "eager": GraniteAttention, - "flash_attention_2": GraniteFlashAttention2, - "sdpa": GraniteSdpaAttention, -} + return attn_output, attn_weights class GraniteDecoderLayer(nn.Module): @@ -503,7 +178,7 @@ def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GRANITE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) self.mlp = GraniteMLP(config) self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -580,6 +255,91 @@ def forward( return outputs +class GraniteRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[GraniteConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GRANITE_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -601,7 +361,6 @@ def forward( "The bare Granite Model outputting raw hidden-states without any specific head on top.", GRANITE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Granite class GranitePreTrainedModel(PreTrainedModel): config_class = GraniteConfig base_model_prefix = "model" @@ -723,17 +482,7 @@ def __init__(self, config: GraniteConfig): [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.gradient_checkpointing = False - - self.embedding_multiplier = config.embedding_multiplier - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - - # rope - self.rotary_emb = GraniteRotaryEmbedding(config) + self.rotary_emb = GraniteRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -757,6 +506,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -777,27 +527,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds * self.embedding_multiplier - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -805,7 +540,6 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - # embed positions hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -814,41 +548,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -858,18 +576,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -879,11 +592,6 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask @@ -906,7 +614,6 @@ def _update_causal_mask( return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] if using_static_cache: target_length = past_key_values.get_max_cache_shape() @@ -917,24 +624,17 @@ def _update_causal_mask( else past_seen_tokens + sequence_length + 1 ) - if attention_mask is not None and attention_mask.dim() == 4: - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -944,12 +644,12 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @staticmethod - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1006,10 +706,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} - # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite def __init__(self, config): super().__init__(config) self.model = GraniteModel(config) @@ -1052,6 +755,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1060,6 +765,11 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: Example: @@ -1067,8 +777,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, GraniteForCausalLM - >>> model = GraniteForCausalLM.from_pretrained("ibm/PowerLM-3b") - >>> tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b") + >>> model = GraniteForCausalLM.from_pretrained("meta-granite/Granite-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-granite/Granite-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1096,26 +806,16 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits / self.config.logits_scaling + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -1128,12 +828,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py new file mode 100644 index 000000000000..f9707ac68d30 --- /dev/null +++ b/src/transformers/models/granite/modular_granite.py @@ -0,0 +1,138 @@ +# coding=utf-8 +# Copyright 2024 IBM and the HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_granite import GraniteConfig +from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM + +logger = logging.get_logger(__name__) + +class GraniteAttention(LlamaAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): + super().__init__() + self.scaling = config.attention_multiplier + + +class GraniteDecoderLayer(nn.Module): + def __init__(self, config: GraniteConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) + + self.mlp = GraniteMLP(config) + self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class GraniteForCausalLM(LlamaForCausalLM): + pass \ No newline at end of file From bc72c3f54428023aca573d18c73ad937249f9589 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 16:38:37 +0100 Subject: [PATCH 13/97] modular modular modular --- .../models/gemma2/modular_gemma2.py | 1 + src/transformers/models/gpt2/modeling_gpt2.py | 4 +- .../models/granite/modular_granite.py | 27 +- .../models/mixtral/modular_mixtral.py | 570 ++++++++++++++++++ 4 files changed, 579 insertions(+), 23 deletions(-) create mode 100644 src/transformers/models/mixtral/modular_mixtral.py diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f8c5ff9adeba..923b3b8d2afe 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -41,6 +41,7 @@ GemmaPreTrainedModel, GemmaRMSNorm, GemmaRotaryEmbedding, + GemmaDecoderLayer, repeat_kv, ) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 63e0d14b2735..d501b036d117 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -19,7 +19,7 @@ import os import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Callable import torch import torch.utils.checkpoint @@ -36,7 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary +from ...modeling_utils import PreTrainedModel, SequenceSummary, ALL_ATTENTION_FUNCTIONS from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index f9707ac68d30..bb3854d74db2 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -13,34 +13,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss - -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS + +from ...cache_utils import Cache from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, - replace_return_docstrings, ) +from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM from .configuration_granite import GraniteConfig -from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM + logger = logging.get_logger(__name__) @@ -135,4 +120,4 @@ def forward( class GraniteForCausalLM(LlamaForCausalLM): - pass \ No newline at end of file + pass diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py new file mode 100644 index 000000000000..8980189f43e3 --- /dev/null +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -0,0 +1,570 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mixtral model.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ..mistral.modeling_mistral import MistralAttention, MistralPreTrainedModel +from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaForCausalLM +from .configuration_mixtral import MixtralConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralRMSNorm(LlamaRMSNorm): + pass + +class MixtralAttention(MistralAttention): + pass + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MixtralAttention(config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + + + +class MixtralModel(MistralPreTrainedModel, MistralForCausalLM): + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + + +class MixtralForCausalLM(LlamaForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) From 48caa890c78e79a3a64f7a82f35d20fbc7a58376 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 12 Dec 2024 16:40:43 +0100 Subject: [PATCH 14/97] nits --- src/transformers/models/gemma2/modular_gemma2.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 4 ++-- src/transformers/models/granite/modular_granite.py | 9 ++++++++- src/transformers/models/mixtral/modular_mixtral.py | 7 ++----- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 923b3b8d2afe..9f7daa5f5124 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -34,6 +34,7 @@ ) from ..gemma.modeling_gemma import ( GemmaAttention, + GemmaDecoderLayer, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, @@ -41,7 +42,6 @@ GemmaPreTrainedModel, GemmaRMSNorm, GemmaRotaryEmbedding, - GemmaDecoderLayer, repeat_kv, ) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index d501b036d117..3ed36095eca4 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -19,7 +19,7 @@ import os import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union, Callable +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -36,7 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, SequenceSummary, ALL_ATTENTION_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel, SequenceSummary from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index bb3854d74db2..42c36fdaadd0 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -23,12 +23,19 @@ from ...utils import ( logging, ) -from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM +from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaMLP, LlamaRMSNorm from .configuration_granite import GraniteConfig logger = logging.get_logger(__name__) +class GraniteRMSNorm(LlamaRMSNorm): + pass + +class GraniteMLP(LlamaMLP): + pass + + class GraniteAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 8980189f43e3..2036775f1228 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -33,13 +33,10 @@ MoeModelOutputWithPast, ) from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, logging, - replace_return_docstrings, ) -from ..mistral.modeling_mistral import MistralAttention, MistralPreTrainedModel -from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaForCausalLM +from ..llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm +from ..mistral.modeling_mistral import MistralAttention, MistralPreTrainedModel, MistralForCausalLM from .configuration_mixtral import MixtralConfig From df68dd0daf5224ba9119072979dc6290b046436f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 13 Dec 2024 09:45:25 +0100 Subject: [PATCH 15/97] update --- .../models/gemma2/modeling_gemma2.py | 11 +- src/transformers/models/glm/modeling_glm.py | 306 +---- src/transformers/models/glm/modular_glm.py | 15 - .../models/granite/modeling_granite.py | 57 +- .../models/mixtral/modeling_mixtral.py | 1158 +++++------------ .../models/mixtral/modular_mixtral.py | 4 +- 6 files changed, 417 insertions(+), 1134 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index b8297d8e6ce8..73536bd0f840 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -35,7 +35,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, @@ -279,9 +279,16 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(GemmaDecoderLayer): +class Gemma2DecoderLayer(GradientCheckpointLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size) self.is_sliding = not bool(layer_idx % 2) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index c01e2f14838f..f9170625a24d 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -29,20 +29,19 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -135,6 +134,23 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., 0::2] @@ -191,246 +207,35 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.is_causal = True + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = 1 / math.sqrt(self.head_dim) - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.view(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GlmFlashAttention2(GlmAttention): - """ - Glm flash attention module. This module inherits from `GlmAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (GlmRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - softmax_scale=self.scaling, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class GlmSdpaAttention(GlmAttention): - """ - Glm attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `GlmAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from GlmAttention.forward def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "GlmModel is using GlmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -440,40 +245,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - scale=self.scaling, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value + return attn_output, attn_weights class GlmDecoderLayer(GradientCheckpointLayer): diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 48605c15d30b..ba6ade2b3e9c 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -28,8 +28,6 @@ ) from ..granite.modeling_granite import ( GraniteAttention, - GraniteFlashAttention2, - GraniteSdpaAttention, ) from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -117,19 +115,6 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.scaling = 1 / math.sqrt(self.head_dim) -class GlmFlashAttention2(GlmAttention, GraniteFlashAttention2): - pass - - -class GlmSdpaAttention(GraniteSdpaAttention): - pass - - -GLM_ATTENTION_CLASSES = { - "eager": GlmAttention, - "flash_attention_2": GlmFlashAttention2, - "sdpa": GlmSdpaAttention, -} class GlmDecoderLayer(LlamaDecoderLayer): diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 20bb591bbda4..3f9621c5cc30 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -24,6 +24,7 @@ import torch from torch import nn +from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -46,6 +47,42 @@ _CONFIG_FOR_DOC = "GraniteConfig" +class GraniteRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -255,26 +292,6 @@ def forward( return outputs -class GraniteRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GraniteRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class GraniteRotaryEmbedding(nn.Module): def __init__( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 0f04ef255c43..22a5b75c7f5d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mixtral/modular_mixtral.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mixtral.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. # @@ -17,142 +23,121 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch Mixtral model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import is_torch_greater_or_equal_than_1_13 -from ...utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, -) -from ...utils.import_utils import is_torch_fx_available +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_mixtral import MixtralConfig -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx - - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" _CONFIG_FOR_DOC = "MixtralConfig" -def load_balancing_loss_func( - gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], - num_experts: Optional[int] = None, - top_k=2, - attention_mask: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, int]: - r""" - Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size - See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss - function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between - experts is too unbalanced. + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - Args: - gate_logits: - Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of - shape [batch_size X sequence_length, num_experts]. - num_experts: - Number of experts - top_k: - The number of experts to route per-token, can be also interpreted as the `top-k` routing - parameter. - attention_mask (`torch.Tensor`, *optional*): - The attention_mask used in forward function - shape [batch_size X sequence_length] if not None. + self.act_fn = ACT2FN[config.hidden_act] - Returns: - The auxiliary loss. + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. """ - if gate_logits is None or not isinstance(gate_logits, tuple): - return 0 - if isinstance(gate_logits, tuple): - compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + # Jitter parameters + self.jitter_noise = config.router_jitter_noise - if attention_mask is None: - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.mean(routing_weights, dim=0) - else: - batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) - # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask - expert_attention_mask = ( - attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) - .reshape(-1, top_k, num_experts) - .to(compute_device) + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert - router_per_expert_attention_mask = ( - attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) - .reshape(-1, num_experts) - .to(compute_device) - ) + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) - # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) - return overall_loss * num_experts + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral class MixtralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -173,45 +158,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -219,9 +165,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb -# TODO @longjie no longer copied from Mistral after static cache -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -229,9 +173,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -242,14 +185,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed -# Copied from transformers.models.llama.modeling_llama.repeat_kv def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, @@ -262,412 +204,85 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class MixtralAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = config.head_dim - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = MixtralRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias ) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralFlashAttention2(MixtralAttention): - """ - Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache -class MixtralSdpaAttention(MixtralAttention): - """ - Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from MixtralAttention.forward def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -MIXTRAL_ATTENTION_CLASSES = { - "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, - "sdpa": MixtralSdpaAttention, -} - - -class MixtralBlockSparseTop2MLP(nn.Module): - def __init__(self, config: MixtralConfig): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralSparseMoeBlock(nn.Module): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accommodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. - """ - - def __init__(self, config): - super().__init__() - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - - # Jitter parameters - self.jitter_noise = config.router_jitter_noise - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - if self.training and self.jitter_noise > 0: - hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, router_logits + return attn_output, attn_weights class MixtralDecoderLayer(nn.Module): @@ -675,7 +290,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = MixtralAttention(config, layer_idx) self.block_sparse_moe = MixtralSparseMoeBlock(config) self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -751,6 +366,71 @@ def forward( return outputs +class MixtralRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[MixtralConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + MIXTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -772,17 +452,17 @@ def forward( "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral -# TODO (Raushan): bring back copied after compile compatibility class MixtralPreTrainedModel(PreTrainedModel): config_class = MixtralConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MixtralDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range @@ -817,7 +497,7 @@ def _init_weights(self, module): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -831,17 +511,24 @@ def _init_weights(self, module): config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the @@ -855,9 +542,6 @@ def _init_weights(self, module): output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - output_router_logits (`bool`, *optional*): - Whether or not to return the logits of all the routers. They are useful for computing the router loss, and - should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): @@ -871,8 +555,6 @@ def _init_weights(self, module): "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", MIXTRAL_START_DOCSTRING, ) -# copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral -# TODO @longjie no longer copied from Mistral after static cache class MixtralModel(MixtralPreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] @@ -890,10 +572,9 @@ def __init__(self, config: MixtralConfig): self.layers = nn.ModuleList( [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MixtralRotaryEmbedding(config=config) - self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -903,7 +584,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Ignore copy @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( self, @@ -918,7 +598,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits @@ -1040,16 +720,24 @@ def forward( router_logits=all_router_logits, ) - # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, + use_cache: bool, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -1117,7 +805,6 @@ def _update_causal_mask( return causal_mask @staticmethod - # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, @@ -1185,8 +872,91 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} def __init__(self, config): super().__init__(config) @@ -1196,6 +966,7 @@ def __init__(self, config): self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing self.post_init() @@ -1218,8 +989,7 @@ def get_decoder(self): return self.model @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1236,7 +1006,7 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, **loss_kwargs, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1327,285 +1097,3 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) - - -@add_start_docstrings( - """ - The Mixtral Model transformer with a sequence classification head on top (linear layer). - - [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - MIXTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL -class MixtralForSequenceClassification(MixtralPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = MixtralModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - MIXTRAL_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL -class MixtralForTokenClassification(MixtralPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = MixtralModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - MIXTRAL_START_DOCSTRING, -) -# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL -class MixtralForQuestionAnswering(MixtralPreTrainedModel): - base_model_prefix = "model" - - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral - def __init__(self, config): - super().__init__(config) - self.model = MixtralModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 2036775f1228..41e696e0d2ca 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -36,7 +36,7 @@ logging, ) from ..llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm -from ..mistral.modeling_mistral import MistralAttention, MistralPreTrainedModel, MistralForCausalLM +from ..mistral.modeling_mistral import MistralAttention, MistralModel, MistralPreTrainedModel from .configuration_mixtral import MixtralConfig @@ -304,7 +304,7 @@ def forward( -class MixtralModel(MistralPreTrainedModel, MistralForCausalLM): +class MixtralModel(MistralModel): def forward( self, From 0325dc46d07792e6dbe7603465c5d952f6a8ecf5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 13 Dec 2024 19:04:09 +0100 Subject: [PATCH 16/97] qwen2 + starcoder2 --- .../modular-transformers/modeling_dummy.py | 5 +- .../modeling_multimodal1.py | 5 +- .../integrations/flash_attention.py | 5 + src/transformers/modeling_utils.py | 4 +- src/transformers/models/aria/modeling_aria.py | 5 +- .../models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modular_gemma2.py | 12 +- src/transformers/models/glm/modeling_glm.py | 5 +- src/transformers/models/glm/modular_glm.py | 2 - src/transformers/models/gpt2/modeling_gpt2.py | 19 +- .../models/granite/modeling_granite.py | 5 +- .../models/granite/modular_granite.py | 2 + .../models/llama/modeling_llama.py | 5 +- .../models/mistral/modeling_mistral.py | 40 +- .../models/mistral/modular_mistral.py | 14 +- .../models/mixtral/modular_mixtral.py | 8 +- src/transformers/models/olmo/modeling_olmo.py | 5 +- src/transformers/models/olmo/modular_olmo.py | 3 +- .../models/olmo2/modeling_olmo2.py | 5 +- .../models/olmo2/modular_olmo2.py | 1 - src/transformers/models/phi/modeling_phi.py | 5 +- .../models/qwen2/modeling_qwen2.py | 1151 ++++++++++++++++- .../models/qwen2/modular_qwen2.py | 101 ++ .../models/starcoder2/modeling_starcoder2.py | 507 ++------ .../models/starcoder2/modular_starcoder2.py | 404 +----- 25 files changed, 1475 insertions(+), 845 deletions(-) create mode 100644 src/transformers/models/qwen2/modular_qwen2.py diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 8fc0a16b89d2..74ab6c8c979f 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -501,7 +501,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index e4ee9eeb07df..b83b3036ae9f 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -501,7 +501,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 2189baef8158..62e030e90c4f 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -13,6 +13,11 @@ def flash_attention_forward( else: seq_len = query.shape[1] + # FA2 uses non-transposed inputs + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + dropout_rate = config.attention_dropout if training else 0.0 input_dtype = query.dtype diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 410774c5eb91..6a8b59703b9e 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1486,8 +1486,8 @@ def _autoset_attn_implementation( message += ( ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' ) - if cls.model_type + "_"+ config._attn_implementation in ALL_ATTENTION_FUNCTIONS: - config._attn_implementation_internal = cls.model_type+ "_" + config._attn_implementation + if cls.model_type + "_" + config._attn_implementation in ALL_ATTENTION_FUNCTIONS: + config._attn_implementation_internal = cls.model_type + "_" + config._attn_implementation if config._attn_implementation in ALL_ATTENTION_FUNCTIONS: pass else: diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index d1546ae220ec..3eda0afde0f6 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -917,7 +917,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index f00e040f6272..ee0829f857dc 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -347,6 +347,7 @@ def extra_repr(self): ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) + class GemmaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -359,7 +360,6 @@ class GemmaPreTrainedModel(LlamaPreTrainedModel): class GemmaModel(LlamaModel): - def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9f7daa5f5124..cb7295c2adf5 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -214,7 +214,14 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): pass -def gemma_eager_attention_forward(config, query, key, value, mask, **_kwargs): +def gemma_eager_attention_forward( + config: Gemma2Config, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor], + **_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: key_states = repeat_kv(key, config.num_key_value_groups) value_states = repeat_kv(value, config.num_key_value_groups) @@ -338,9 +345,11 @@ def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs): "gemma_sdpa": gemma_sdpa_attention_forward, } + class Gemma2Attention(GemmaAttention): pass + class Gemma2DecoderLayer(GemmaDecoderLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() @@ -411,6 +420,7 @@ def forward( class Gemma2PreTrainedModel(GemmaPreTrainedModel): pass + class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): def __init__(self, config: Gemma2Config): super().__init__(config) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f9170625a24d..c7f343b36dc7 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -516,7 +516,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index ba6ade2b3e9c..42192f1b7d6b 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -115,8 +115,6 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.scaling = 1 / math.sqrt(self.head_dim) - - class GlmDecoderLayer(LlamaDecoderLayer): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 3ed36095eca4..e37a8100760e 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -157,6 +157,7 @@ def eager_attention_forward(self, query, key, value, attention_mask=None, head_m return attn_output, attn_weights + class GPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() @@ -217,8 +218,6 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) bsz, num_heads, q_seq_len, dk = query.size() @@ -296,8 +295,6 @@ def forward( else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - - input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -316,17 +313,16 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) else: attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - **kwargs, - ) + self, + query_states, + key_states, + value_states, + **kwargs, + ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) @@ -351,7 +347,6 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states - class GPT2Block(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 3f9621c5cc30..9ea2c2c3ec16 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -548,7 +548,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 42c36fdaadd0..e276fa5ef1f1 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -29,9 +29,11 @@ logger = logging.get_logger(__name__) + class GraniteRMSNorm(LlamaRMSNorm): pass + class GraniteMLP(LlamaMLP): pass diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index ccb15fef2ef5..0c362dd8691d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -536,7 +536,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c983ea15e433..9b707a27b409 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -28,10 +28,27 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "meta-mistral/Mistral-2-7b-hf" + +_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" _CONFIG_FOR_DOC = "MistralConfig" +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class MistralRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -117,22 +134,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class MistralMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -509,7 +510,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 2d36df948d3e..fd0c86218050 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -1,5 +1,6 @@ import torch import torch.utils.checkpoint +from torch import nn from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -7,11 +8,23 @@ LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, + LlamaMLP, LlamaModel, ) from .configuration_mistral import MistralConfig +_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" + + +class MistralMLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + class MistralModel(LlamaModel): def _update_causal_mask( self, @@ -175,4 +188,3 @@ class MistralForSequenceClassification(LlamaForSequenceClassification): class MistralForQuestionAnswering(LlamaForQuestionAnswering): pass - diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 41e696e0d2ca..697388535499 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -36,7 +36,7 @@ logging, ) from ..llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm -from ..mistral.modeling_mistral import MistralAttention, MistralModel, MistralPreTrainedModel +from ..mistral.modeling_mistral import MistralAttention, MistralModel from .configuration_mixtral import MixtralConfig @@ -217,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class MixtralRMSNorm(LlamaRMSNorm): pass + class MixtralAttention(MistralAttention): pass @@ -302,10 +303,7 @@ def forward( return outputs - - class MixtralModel(MistralModel): - def forward( self, input_ids: torch.LongTensor = None, @@ -442,7 +440,6 @@ def forward( ) - class MixtralForCausalLM(LlamaForCausalLM): _tied_weights_keys = ["lm_head.weight"] @@ -457,7 +454,6 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() - def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 478343d2d24b..f8f6a884c091 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -541,7 +541,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 8f397c512a75..e43976969ace 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -104,6 +104,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + class OlmoForCausalLM(LlamaForCausalLM): pass @@ -118,5 +119,3 @@ class OlmoForSequenceClassification(LlamaForSequenceClassification): class OlmoForQuestionAnswering(LlamaForQuestionAnswering): pass - - diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 9931346e1c29..31cb6b487144 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -511,7 +511,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 6cee9d8c35a5..a169f72c8411 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -192,7 +192,6 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 73653b404c6a..64364d54b3a4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -540,7 +540,10 @@ def forward( past_key_values = DynamicCache() if cache_position is None: - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: position_ids = cache_position.unsqueeze(0) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b956578da1f8..dcba03eafe41 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1,78 +1,1111 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Qwen2 model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union, Unpack import torch -import torch.utils.checkpoint +from torch import nn -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...utils import ( - is_flash_attn_2_available, + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings, ) +from .configuration_qwen2 import Qwen2Config + +logger = logging.get_logger(__name__) -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +_CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf" +_CONFIG_FOR_DOC = "Qwen2Config" -logger = logging.get_logger(__name__) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. -def qwen2_flash_attention_forward( - config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs -): + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling if attention_mask is not None: - seq_len = attention_mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - else: - seq_len = query.shape[1] - - dropout_rate = config.attention_dropout if training else 0.0 - - sliding_window = None - if ( - config.use_sliding_window - and getattr(config, "sliding_window", None) is not None - and layer_idx >= config.max_window_layers + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Qwen2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + sliding_window=sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Qwen2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Qwen2DecoderLayer(GradientCheckpointLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class Qwen2RotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[Qwen2Config] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +QWEN2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Qwen2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2PreTrainedModel(PreTrainedModel): + config_class = Qwen2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +QWEN2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", + QWEN2_START_DOCSTRING, +) +class Qwen2Model(Qwen2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] + + Args: + config: Qwen2Config + """ + + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2RotaryEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, ): - sliding_window = config.sliding_window - - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - seq_len, - config=config, - dropout=dropout_rate, - layer_idx=layer_idx, - sliding_window=sliding_window, + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-qwen2/Qwen2-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a sequence classification head on top (linear layer). + + [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForTokenClassification(Qwen2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + QWEN2_START_DOCSTRING, +) +class Qwen2ForQuestionAnswering(Qwen2PreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = Qwen2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() - return attn_output, None + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) -ALL_ATTENTION_FUNCTIONS["qwen2_flash_attention2"] = qwen2_flash_attention_forward + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output -# TODO mask creation is also different for Qwen2 -# MAKE SURE THAT config._attn_implementation is automatically qwen2_flash_attention2 + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py new file mode 100644 index 000000000000..4f55ed765d27 --- /dev/null +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -0,0 +1,101 @@ +from typing import Callable, List, Optional, Tuple, Union, Unpack + +import torch +from torch import nn +import torch.utils.checkpoint + +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...utils import logging +from ..llama.modeling_llama import ( + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaModel, + LlamaForCausalLM, + LlamaAttention, + LlamaDecoderLayer, + eager_attention_forward, + apply_rotary_pos_emb, +) +from .configuration_qwen2 import Qwen2Config + + +logger = logging.get_logger(__name__) + +class Qwen2Attention(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + sliding_window = None + if ( + self.config.use_sliding_window + and getattr(self.config, "sliding_window", None) is not None + and self.layer_idx >= self.config.max_window_layers + ): + sliding_window = self.config.sliding_window + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + sliding_window=sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen2DecoderLayer(LlamaDecoderLayer): + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__() + if config.sliding_window and config._attn_implementation != "flash_attention_2": + logger.warning_once( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." + ) + +class Qwen2Model(LlamaModel): + pass + +class Qwen2ForCausalLM(LlamaForCausalLM): + pass + +class Qwen2ForSequenceClassification(LlamaForSequenceClassification): + pass + +class Qwen2ForTokenClassification(LlamaForTokenClassification): + pass + +class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering): + pass \ No newline at end of file diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ed1651996793..263d77e6ca16 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -24,16 +24,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -41,23 +41,19 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_starcoder2 import Starcoder2Config -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" _CONFIG_FOR_DOC = "Starcoder2Config" @@ -191,309 +187,101 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + class Starcoder2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.use_bias = config.use_bias - self.is_causal = True + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2FlashAttention2(Starcoder2Attention): - """ - Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reshape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2SdpaAttention(Starcoder2Attention): - """ - Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - # The difference with Mistral is that here it uses dropout attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - return attn_output, None, past_key_value + return attn_output, attn_weights -STARCODER2_ATTENTION_CLASSES = { - "eager": Starcoder2Attention, - "flash_attention_2": Starcoder2FlashAttention2, - "sdpa": Starcoder2SdpaAttention, -} - - -class Starcoder2DecoderLayer(nn.Module): +class Starcoder2DecoderLayer(GradientCheckpointLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) self.mlp = Starcoder2MLP(config) - self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -502,35 +290,13 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -545,6 +311,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) hidden_states = residual + hidden_states @@ -591,7 +358,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["Starcoder2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True @@ -631,7 +398,7 @@ def _init_weights(self, module): Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] @@ -706,12 +473,10 @@ def __init__(self, config: Starcoder2Config): self.layers = nn.ModuleList( [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self._attn_implementation = config._attn_implementation self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.rotary_emb = Starcoder2RotaryEmbedding(config=config) - - self.gradient_checkpointing = False self.embedding_dropout = config.embedding_dropout + # Initialize weights and apply final processing self.post_init() @@ -727,54 +492,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -791,41 +545,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -835,18 +573,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -866,30 +599,21 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: + if using_static_cache: target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -906,8 +630,6 @@ def _update_causal_mask( device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, ) if ( @@ -919,6 +641,7 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -932,8 +655,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( device: torch.device, cache_position: torch.Tensor, batch_size: int, - config: Starcoder2Config, - past_key_values: Cache, + **kwargs, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -941,11 +663,13 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -954,10 +678,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. - config (`Starcoder2Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -967,30 +687,25 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) + return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -1029,7 +744,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1038,7 +753,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1059,8 +774,8 @@ def forward( ```python >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM - >>> model = Starcoder2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = Starcoder2ForCausalLM.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-starcoder2/Starcoder2-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") @@ -1070,7 +785,6 @@ def forward( >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1089,6 +803,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1097,7 +812,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 013c8e472b32..bafe3a4fba59 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -19,8 +19,7 @@ # limitations under the License. """PyTorch Starcoder2 model.""" -import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -28,28 +27,32 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, ) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import ( add_start_docstrings_to_model_forward, is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) -from ..qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM, Qwen2Model, Qwen2PreTrainedModel +from ..qwen2.modeling_qwen2 import Qwen2ForCausalLM, Qwen2Model from .configuration_starcoder2 import Starcoder2Config if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass logger = logging.get_logger(__name__) @@ -79,317 +82,80 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -class Starcoder2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ +def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): + config = attention_class.config + key_states = repeat_kv(key, attention_class.num_key_value_groups) + value_states = repeat_kv(value, attention_class.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class Starcoder2Attention(LlamaAttention): def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.rope_theta = config.rope_theta - self.use_bias = config.use_bias - self.is_causal = True self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias) - def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights += causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2FlashAttention2(Starcoder2Attention): - """ - Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) dropout_rate = 0.0 if not self.training else self.attention_dropout - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reshape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attention_mask, - q_len, - position_ids=position_ids, dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Starcoder2SdpaAttention(Starcoder2Attention): - """ - Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, + **kwargs, ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - # The difference with Mistral is that here it uses dropout attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) - return attn_output, None, past_key_value + return attn_output, attn_weights -STARCODER2_ATTENTION_CLASSES = { - "eager": Starcoder2Attention, - "flash_attention_2": Starcoder2FlashAttention2, - "sdpa": Starcoder2SdpaAttention, -} - - -class Starcoder2DecoderLayer(Qwen2DecoderLayer, nn.Module): +class Starcoder2DecoderLayer(LlamaDecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - - self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - - self.mlp = Starcoder2MLP(config) - + super().__init__(self) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) -class Starcoder2PreTrainedModel(Qwen2PreTrainedModel): - pass - - STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined @@ -412,54 +178,43 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -476,41 +231,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: + for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -520,18 +259,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() class Starcoder2ForCausalLM(Qwen2ForCausalLM): From ecd814bdc28bfb559756ee3aa3a5082c2fd401ef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 12:28:27 +0100 Subject: [PATCH 17/97] mostly gemma2 --- .../modular-transformers/modeling_dummy.py | 2 +- .../modeling_multimodal1.py | 2 +- .../modeling_my_new_model2.py | 92 ++++----- .../modular-transformers/modeling_super.py | 2 +- .../integrations/flash_attention.py | 31 ++- .../integrations/flex_attention.py | 20 +- .../integrations/sdpa_attention.py | 18 +- src/transformers/modeling_utils.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- .../models/cohere2/modeling_cohere2.py | 4 - .../models/gemma/modeling_gemma.py | 92 ++++----- .../models/gemma/modular_gemma.py | 1 + .../models/gemma2/modeling_gemma2.py | 42 +++- .../models/gemma2/modular_gemma2.py | 183 +++++++----------- src/transformers/models/glm/modeling_glm.py | 2 +- .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2/modular_qwen2.py | 27 +-- .../models/starcoder2/modeling_starcoder2.py | 2 +- 21 files changed, 272 insertions(+), 260 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 74ab6c8c979f..1d336845de23 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -272,7 +272,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index b83b3036ae9f..78e37aa123d1 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -272,7 +272,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index fa89cc368380..cc26a94defc6 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -207,7 +207,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -245,6 +245,51 @@ def forward( return outputs +MY_NEW_MODEL2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MyNewModel2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", + MY_NEW_MODEL2_START_DOCSTRING, +) +class MyNewModel2PreTrainedModel(PreTrainedModel): + config_class = MyNewModel2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MyNewModel2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class MyNewModel2RotaryEmbedding(nn.Module): def __init__( self, @@ -310,51 +355,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MY_NEW_MODEL2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MyNewModel2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.", - MY_NEW_MODEL2_START_DOCSTRING, -) -class MyNewModel2PreTrainedModel(PreTrainedModel): - config_class = MyNewModel2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["MyNewModel2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 0bec83258924..cfbaa5ac6e64 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -269,7 +269,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 62e030e90c4f..333a22bbcf98 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,13 +1,25 @@ +from typing import Optional + import torch from ..modeling_flash_attention_utils import _flash_attention_forward def flash_attention_forward( - config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + sliding_window: Optional[int] = None, + softcap: Optional[float] = None, + target_dtype: torch.dtype = torch.float16, + **kwargs, ): - if attentions_mask is not None: - seq_len = attentions_mask.shape[1] + if attention_mask is not None: + seq_len = attention_mask.shape[1] query = query[:, :, :seq_len] value = value[:, :, :seq_len] else: @@ -18,8 +30,6 @@ def flash_attention_forward( key = key.transpose(1, 2) value = value.transpose(1, 2) - dropout_rate = config.attention_dropout if training else 0.0 - input_dtype = query.dtype if input_dtype == torch.float32: query = query.to(target_dtype) @@ -30,11 +40,14 @@ def flash_attention_forward( query, key, value, - attentions_mask, + attention_mask, seq_len, - config=config, - dropout=dropout_rate, - layer_idx=layer_idx, + module.is_causal, + dropout=dropout, + softmax_scale=scaling, + sliding_window=sliding_window, + softcap=softcap, + use_top_left_mask=module._flash_attn_uses_top_left_mask, **kwargs, ) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 9c309a9ad505..dd4287921d2c 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,3 +1,7 @@ +from typing import Optional + +import torch + from ..utils import is_torch_greater_or_equal @@ -5,12 +9,23 @@ from torch.nn.attention.flex_attention import flex_attention -def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs): +def flex_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + softcap: Optional[float] = None, + **kwargs, +): causal_mask = attention_mask if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] def causal_mod(score, b, h, q_idx, kv_idx): + if softcap is not None: + score = softcap * torch.tanh(score / softcap) if causal_mask is not None: score += causal_mask[b][0][q_idx][kv_idx] return score @@ -21,8 +36,9 @@ def causal_mod(score, b, h, q_idx, kv_idx): value, score_mod=causal_mod, enable_gqa=True, - scale=module.scaling, + scale=scaling, return_lse=True, ) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attention_weights diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 0cf58f035ea9..3a90ef9af282 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch @@ -13,7 +15,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs): +def sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) @@ -31,9 +42,10 @@ def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kw key, value, attn_mask=causal_mask, - dropout_p=module.config.attention_dropout if module.training else 0.0, + dropout_p=dropout, + scale=scaling, is_causal=is_causal, - scale=module.scaling, ) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6a8b59703b9e..2cb4c38247d2 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -30,7 +30,7 @@ from dataclasses import dataclass from functools import partial, wraps from threading import Thread -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from zipfile import is_zipfile import torch diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 3eda0afde0f6..dc18a8324207 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -588,7 +588,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 6b19d178341f..c328690efc8c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -659,10 +659,6 @@ def __init__(self, config: Cohere2Config): [Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) - - self.gradient_checkpointing = False - if getattr(config, "pretraining_tp", 1) != 1: - logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.") self.rotary_emb = Cohere2RotaryEmbedding(config=config) # Initialize weights and apply final processing diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 12b6cd81451b..ced8b374ba04 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -237,7 +237,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -275,6 +275,51 @@ def forward( return outputs +GEMMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GemmaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma Model outputting raw hidden-states without any specific head on top.", + GEMMA_START_DOCSTRING, +) +class GemmaPreTrainedModel(PreTrainedModel): + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + class GemmaRotaryEmbedding(nn.Module): def __init__( self, @@ -340,51 +385,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -GEMMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`GemmaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Gemma Model outputting raw hidden-states without any specific head on top.", - GEMMA_START_DOCSTRING, -) -class GemmaPreTrainedModel(PreTrainedModel): - config_class = GemmaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["GemmaDecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index ee0829f857dc..69972013ab28 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -32,6 +32,7 @@ LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, + LlamaPreTrainedModel, ) from ..llama.tokenization_llama import LlamaTokenizer diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 73536bd0f840..3f66f2254cf0 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -50,6 +50,8 @@ logger = logging.get_logger(__name__) +_CHECKPOINT_FOR_DOC = "google/gemma2-7b" +_CONFIG_FOR_DOC = "Gemma2Config" class Gemma2RMSNorm(nn.Module): @@ -198,18 +200,30 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor], + **kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + + if module.attn_logit_softcapping is not None: + attn_weights = attn_weights / module.attn_logit_softcapping + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * module.attn_logit_softcapping + if mask is not None: # no matter the length, we just slice it + causal_mask = mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask + # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -224,7 +238,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 + self.scaling = config.query_pre_attn_scalar**-0.5 self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -238,6 +252,10 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( self, @@ -271,6 +289,10 @@ def forward( query_states, key_states, value_states, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, **kwargs, ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index cb7295c2adf5..eecc2f886109 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -22,13 +22,15 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache from ...configuration_utils import PretrainedConfig +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import ( is_flash_attn_2_available, - is_flash_attn_greater_or_equal, is_torch_greater_or_equal, logging, ) @@ -42,15 +44,16 @@ GemmaPreTrainedModel, GemmaRMSNorm, GemmaRotaryEmbedding, + apply_rotary_pos_emb, repeat_kv, ) if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward + pass if is_torch_greater_or_equal("2.5"): - from torch.nn.attention.flex_attention import flex_attention + pass _CHECKPOINT_FOR_DOC = "google/gemma2-7b" @@ -214,140 +217,86 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): pass -def gemma_eager_attention_forward( - config: Gemma2Config, +def eager_attention_forward( + module: nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor], - **_kwargs, + **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - key_states = repeat_kv(key, config.num_key_value_groups) - value_states = repeat_kv(value, config.num_key_value_groups) + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * config.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling - if config.attn_logit_softcapping is not None: - attn_weights = attn_weights / config.attn_logit_softcapping + if module.attn_logit_softcapping is not None: + attn_weights = attn_weights / module.attn_logit_softcapping attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * config.attn_logit_softcapping + attn_weights = attn_weights * module.attn_logit_softcapping if mask is not None: # no matter the length, we just slice it causal_mask = mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=config.training) + attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights -def gemma_flash_attention_forward(config, query, key, value, mask, target_dtype=torch.float16, **_kwargs): - if mask is not None: - seq_len = mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding - query_states = query.transpose(1, 2) - key_states = key.transpose(1, 2) - value_states = value.transpose(1, 2) - - dropout_rate = config.attention_dropout if config.training else 0.0 - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - seq_len, - dropout=dropout_rate, - softmax_scale=config.scaling, - is_causal=config.is_causal, - sliding_window=config.sliding_window, - use_top_left_mask=config._flash_attn_uses_top_left_mask, - softcap=config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, - ) - - return attn_output, None - - -def gemma_flex_attention_forward(config, query, key, value, mask, output_attentions=False, **_kwargs): - def tanh_softcap(score, b, h, q_idx, kv_idx): - soft_cap = config.attn_logit_softcapping - score = soft_cap * torch.tanh(score / soft_cap) - if mask is not None: - return score + mask[b][0][q_idx][kv_idx] - return score - - attn_output = flex_attention( - query, - key, - value, - score_mod=tanh_softcap, - enable_gqa=True, - scale=config.scaling, - return_lse=output_attentions, - ) - if not output_attentions: - attn_weights = None - else: - attn_output, attn_weights = attn_output - - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -def gemma_sdpa_attention_forward(config, query, key, value, mask, **_kwargs): - key = repeat_kv(key, config.num_key_value_groups) - value = repeat_kv(value, config.num_key_value_groups) - - causal_mask = mask - if mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query.device.type == "cuda" and causal_mask is not None: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and query.shape[1] > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=config.attention_dropout if config.training else 0.0, - is_causal=is_causal, - scale=config.scaling, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - -ALL_ATTENTION_FUNCTION = { - "gemma_flash_attention_2": gemma_flash_attention_forward, - "gemma_flex_attention": gemma_flex_attention_forward, - "gemma_eager": gemma_eager_attention_forward, - "gemma_sdpa": gemma_sdpa_attention_forward, -} +class Gemma2Attention(GemmaAttention): + def __init__(self, config: Gemma2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.attn_logit_softcapping = self.config.attn_logit_softcapping + self.attention_dropout = self.config.attention_dropout + self.is_causal = True + self.scaling = config.query_pre_attn_scalar**-0.5 + self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + dropout = self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + softcap=self.attn_logit_softcapping, + **kwargs, + ) -class Gemma2Attention(GemmaAttention): - pass + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights class Gemma2DecoderLayer(GemmaDecoderLayer): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index c7f343b36dc7..0f14c7140efe 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -282,7 +282,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0c362dd8691d..9726d22fed37 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -307,7 +307,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9b707a27b409..2a4dcbefc344 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -281,7 +281,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f8f6a884c091..f634c52f2fc7 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -247,7 +247,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index dcba03eafe41..24c1c6ce3e65 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -237,7 +237,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 4f55ed765d27..edadaa6e1188 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -1,32 +1,30 @@ -from typing import Callable, List, Optional, Tuple, Union, Unpack +from typing import Callable, Optional, Tuple, Unpack import torch -from torch import nn import torch.utils.checkpoint -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...cache_utils import Cache, SlidingWindowCache, StaticCache -from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import logging from ..llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, - LlamaForCausalLM, - LlamaAttention, - LlamaDecoderLayer, - eager_attention_forward, apply_rotary_pos_emb, + eager_attention_forward, ) from .configuration_qwen2 import Qwen2Config logger = logging.get_logger(__name__) -class Qwen2Attention(LlamaAttention): +class Qwen2Attention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, @@ -75,7 +73,7 @@ def forward( attn_output = self.o_proj(attn_output) return attn_output, attn_weights - + class Qwen2DecoderLayer(LlamaDecoderLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() @@ -85,17 +83,22 @@ def __init__(self, config: Qwen2Config, layer_idx: int): "unexpected results may be encountered." ) + class Qwen2Model(LlamaModel): pass + class Qwen2ForCausalLM(LlamaForCausalLM): pass + class Qwen2ForSequenceClassification(LlamaForSequenceClassification): pass + class Qwen2ForTokenClassification(LlamaForTokenClassification): pass + class Qwen2ForQuestionAnswering(LlamaForQuestionAnswering): - pass \ No newline at end of file + pass diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 263d77e6ca16..5ca59d1257f0 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -294,7 +294,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states From f5fc638db017bc9c16c1630919fa3290bc24b7ed Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 12:39:15 +0100 Subject: [PATCH 18/97] Update image_processing_auto.py --- src/transformers/models/auto/image_processing_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index 00b0c049c91d..db25591eaa35 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -30,6 +30,7 @@ CONFIG_NAME, IMAGE_PROCESSOR_NAME, get_file_from_repo, + is_timm_config_dict, is_timm_local_checkpoint, is_torchvision_available, is_vision_available, @@ -42,7 +43,6 @@ model_type_to_module_name, replace_list_option_in_docstrings, ) -from .utils.generic import is_timm_config_dict logger = logging.get_logger(__name__) From 5e56d9c0b0607a91d65e6c5195f2710932e7260c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 12:46:34 +0100 Subject: [PATCH 19/97] fix --- src/transformers/modeling_utils.py | 4 +++- src/transformers/models/gemma2/modular_gemma2.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2cb4c38247d2..16ce1dbb0ec0 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -30,7 +30,7 @@ from dataclasses import dataclass from functools import partial, wraps from threading import Thread -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union from zipfile import is_zipfile import torch @@ -174,6 +174,8 @@ def is_local_dist_rank_0(): if is_peft_available(): from .utils import find_adapter_config_file +SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") + TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, "normal_": nn.init.normal_, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index eecc2f886109..0a2e9e23fe01 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -287,7 +287,7 @@ def forward( query_states, key_states, value_states, - dropout = self.attention_dropout if self.training else 0.0, + dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, softcap=self.attn_logit_softcapping, From 598b7bb58b558acddda03f13fc13ffc1a215b36e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 12:52:04 +0100 Subject: [PATCH 20/97] Update modular_starcoder2.py --- src/transformers/models/starcoder2/modular_starcoder2.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index bafe3a4fba59..e8c1d5ced253 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -132,14 +132,12 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - dropout_rate = 0.0 if not self.training else self.attention_dropout - attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, - dropout=dropout_rate, + dropout=0.0 if not self.training else self.attention_dropout, **kwargs, ) @@ -283,7 +281,7 @@ class Starcoder2ForTokenClassification(LlamaForTokenClassification): __all__ = [ "Starcoder2ForCausalLM", "Starcoder2Model", - "Starcoder2PreTrainedModel", + "Starcoder2PreTrainedModel", # noqa: F822 "Starcoder2ForSequenceClassification", "Starcoder2ForTokenClassification", ] From 0f565fbfd7d5f6c315e16951bd4b288138c5891e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 12:53:07 +0100 Subject: [PATCH 21/97] fix --- src/transformers/models/starcoder2/modeling_starcoder2.py | 4 +--- src/transformers/models/starcoder2/modular_starcoder2.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5ca59d1257f0..ee515d77611f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -257,14 +257,12 @@ def forward( if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - dropout_rate = 0.0 if not self.training else self.attention_dropout - attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, - dropout=dropout_rate, + dropout=0.0 if not self.training else self.attention_dropout, **kwargs, ) diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index e8c1d5ced253..d4163b9f478b 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -281,7 +281,7 @@ class Starcoder2ForTokenClassification(LlamaForTokenClassification): __all__ = [ "Starcoder2ForCausalLM", "Starcoder2Model", - "Starcoder2PreTrainedModel", # noqa: F822 + "Starcoder2PreTrainedModel", # noqa: F822 "Starcoder2ForSequenceClassification", "Starcoder2ForTokenClassification", ] From c9ac84d499c5f1bf3f0abd52d6e4b5a101901780 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 13:20:53 +0100 Subject: [PATCH 22/97] remove all copied from attentions --- .../modeling_flash_attention_utils.py | 2 +- src/transformers/models/bark/modeling_bark.py | 1 - src/transformers/models/bart/modeling_bart.py | 1 - .../models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/clip/modeling_clip.py | 1 - .../models/cohere/modeling_cohere.py | 3 +- .../data2vec/modeling_data2vec_audio.py | 1 - src/transformers/models/dbrx/modeling_dbrx.py | 1 - .../models/distilbert/modeling_distilbert.py | 1 - .../models/falcon/modeling_falcon.py | 1 - .../gpt_bigcode/modeling_gpt_bigcode.py | 1 - .../models/gpt_neo/modeling_gpt_neo.py | 1 - src/transformers/models/gptj/modeling_gptj.py | 1 - .../models/granitemoe/modeling_granitemoe.py | 6 +- .../models/hubert/modeling_hubert.py | 1 - .../models/idefics2/modeling_idefics2.py | 5 +- .../models/idefics3/modeling_idefics3.py | 1 - .../models/jamba/modeling_jamba.py | 1 - .../models/jetmoe/modeling_jetmoe.py | 1 - .../models/m2m_100/modeling_m2m_100.py | 1 - .../models/mbart/modeling_mbart.py | 1 - src/transformers/models/mimi/modeling_mimi.py | 6 +- .../models/mixtral/modeling_mixtral.py | 291 +++++++++++++++++- .../models/mixtral/modular_mixtral.py | 20 +- .../models/moshi/modeling_moshi.py | 6 +- .../models/musicgen/modeling_musicgen.py | 1 - .../modeling_musicgen_melody.py | 1 - .../models/nemotron/modeling_nemotron.py | 6 +- .../models/olmoe/modeling_olmoe.py | 1 - src/transformers/models/opt/modeling_opt.py | 1 - src/transformers/models/phi3/modeling_phi3.py | 5 +- .../qwen2_audio/modeling_qwen2_audio.py | 1 - .../models/qwen2_moe/modeling_qwen2_moe.py | 7 +- src/transformers/models/sew/modeling_sew.py | 1 - .../models/siglip/modeling_siglip.py | 1 - .../models/stablelm/modeling_stablelm.py | 1 - .../models/unispeech/modeling_unispeech.py | 1 - .../unispeech_sat/modeling_unispeech_sat.py | 1 - .../models/wav2vec2/modeling_wav2vec2.py | 1 - .../models/whisper/modeling_whisper.py | 1 - .../models/zamba/modeling_zamba.py | 1 - 41 files changed, 337 insertions(+), 51 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index ec03ba1eb5fd..ba034772e03c 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -276,7 +276,7 @@ def _flash_attention_forward( if not use_top_left_mask: causal = is_causal else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. causal = is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 9e225ac9ae15..36a278263b55 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -197,7 +197,6 @@ class BarkSelfFlashAttention2(BarkSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index dd1b69c8127f..4e1f0b389d42 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -294,7 +294,6 @@ class BartFlashAttention2(BartAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index f01665201bfa..11bc411a00c0 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -362,7 +362,7 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon +# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon # TODO(joao): add me back asap :) class ChameleonFlashAttention2(ChameleonAttention): """ diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 4751bb91aace..0bd9c9c0abce 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -401,7 +401,6 @@ class CLIPFlashAttention2(CLIPAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b9a235ed500c..72f6f1c82e86 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -351,7 +351,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere +# TODO cyril: modular class CohereFlashAttention2(CohereAttention): """ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 590509eaf905..03102d22ca0d 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -489,7 +489,6 @@ class Data2VecAudioFlashAttention2(Data2VecAudioAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7d20b766658f..990b4d821e5b 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -318,7 +318,6 @@ class DbrxFlashAttention2(DbrxAttention): calls the public API of flash attention. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 36e35594b3d3..a826272956e5 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -245,7 +245,6 @@ class DistilBertFlashAttention2(MultiHeadSelfAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 51d9ff39d48f..b2ca3ffed67b 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -492,7 +492,6 @@ class FalconFlashAttention2(FalconAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 5326c7b907d4..403159cdf39c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -278,7 +278,6 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 28bfbabc1fd8..6763695bfba0 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -278,7 +278,6 @@ class GPTNeoFlashAttention2(GPTNeoSelfAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 1cc9cf369d18..4af8f73b5f5e 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -266,7 +266,6 @@ class GPTJFlashAttention2(GPTJAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 9f5fdeea07d4..6d706c463270 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -510,7 +510,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe +# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe +# TODO cyril: modular class GraniteMoeFlashAttention2(GraniteMoeAttention): """ GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stays @@ -617,7 +618,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe +# NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteSdpaAttention with Granite->GraniteMoe +# TODO cyril: modular class GraniteMoeSdpaAttention(GraniteMoeAttention): """ GraniteMoe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 03904a6abfa0..1629f7d4f3fe 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -563,7 +563,6 @@ class HubertFlashAttention2(HubertAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 3d46c3bd82e7..6d7295b5120d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -272,7 +272,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -859,7 +858,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 +# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 +# TODO cyril: modular class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): """ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays @@ -867,7 +867,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 31d43948fbd5..3a52b8b6d54d 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -273,7 +273,6 @@ class Idefics3VisionFlashAttention2(Idefics3VisionAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index a185d5ebc6e8..07643d243e2a 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -384,7 +384,6 @@ class JambaFlashAttention2(JambaAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index a4bb1d78fdc5..cf30e9c02405 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -641,7 +641,6 @@ def forward( class JetMoeFlashAttention2(JetMoeAttention): - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index cc35a3504255..4e116e7e3db5 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -348,7 +348,6 @@ class M2M100FlashAttention2(M2M100Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 95cd7c65ef32..e272c98f0697 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -291,7 +291,6 @@ class MBartFlashAttention2(MBartAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index cbdd2c663c58..3ba64a554831 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -559,7 +559,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi +# TODO cyril: modular class MimiFlashAttention2(MimiAttention): """ Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays @@ -670,7 +671,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Mimi +# TODO cyril: modular class MimiSdpaAttention(MimiAttention): """ Mimi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 22a5b75c7f5d..53482843b3c0 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -40,15 +40,26 @@ CausalLMOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_mixtral import MixtralConfig logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1" _CONFIG_FOR_DOC = "MixtralConfig" @@ -1097,3 +1108,281 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a sequence classification head on top (linear layer). + + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MIXTRAL_START_DOCSTRING, +) +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + MIXTRAL_START_DOCSTRING, +) +class MixtralForTokenClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.config) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + MIXTRAL_START_DOCSTRING, +) +class MixtralForQuestionAnswering(MixtralPreTrainedModel): + base_model_prefix = "transformer" + + def __init__(self, config): + super().__init__(config) + self.transformer = MixtralModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 697388535499..a925998a5228 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -35,7 +35,13 @@ from ...utils import ( logging, ) -from ..llama.modeling_llama import LlamaForCausalLM, LlamaRMSNorm +from ..llama.modeling_llama import ( + LlamaForCausalLM, + LlamaForQuestionAnswering, + LlamaForSequenceClassification, + LlamaForTokenClassification, + LlamaRMSNorm, +) from ..mistral.modeling_mistral import MistralAttention, MistralModel from .configuration_mixtral import MixtralConfig @@ -561,3 +567,15 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) + + +class MixtralForSequenceClassification(LlamaForSequenceClassification): + pass + + +class MixtralForTokenClassification(LlamaForTokenClassification): + pass + + +class MixtralForQuestionAnswering(LlamaForQuestionAnswering): + pass diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 82abfa66c2e8..f3d2f2d9d93d 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -527,7 +527,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi +# TODO cyril: modular class MoshiFlashAttention2(MoshiAttention): """ Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stays @@ -643,7 +644,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +# NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi +# TODO cyril: modular class MoshiSdpaAttention(MoshiAttention): """ Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 109ddfb626d2..f83bccb7e4f6 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -324,7 +324,6 @@ class MusicgenFlashAttention2(MusicgenAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 61f2ce414e1d..dc0e9b882b20 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -340,7 +340,6 @@ class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 78dace1a53ce..a7cb80a19433 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -301,7 +301,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular class NemotronFlashAttention2(NemotronAttention): """ Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays @@ -415,7 +416,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# NO LONGER EXIST Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# TODO cyril: modular class NemotronSdpaAttention(NemotronAttention): """ Nemotron attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 4398e2f5c9a1..a4b20ee1207e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -422,7 +422,6 @@ class OlmoeFlashAttention2(OlmoeAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index e4ef510f099d..3350ae1a23c2 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -257,7 +257,6 @@ class OptFlashAttention2(OPTAttention): attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index bae3f6d4cdae..d9d7db63e27f 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -431,7 +431,6 @@ class Phi3FlashAttention2(Phi3Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -550,8 +549,8 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 -# TODO @Arthur no longer copied from LLama after static cache +# NO LONGER EXIST copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO cyril: modular class Phi3SdpaAttention(Phi3Attention): """ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index ce0e427048cf..44a5b5ce3155 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -223,7 +223,6 @@ class Qwen2AudioFlashAttention2(Qwen2AudioAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a1e36b8ad7bc..0de6d3a1a727 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -419,7 +419,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe +# TODO cyril: modular class Qwen2MoeFlashAttention2(Qwen2MoeAttention): """ Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` @@ -429,7 +430,6 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): config.max_window_layers layers. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -530,7 +530,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +# NO LONGER EXIST Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe +# TODO cyril: modular class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ Qwen2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 8638d9338584..1959d21e1d5d 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -563,7 +563,6 @@ class SEWFlashAttention2(SEWAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index a42bcd0e1746..9a2dfe013716 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -438,7 +438,6 @@ class SiglipFlashAttention2(SiglipAttention): is_causal = False - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 0ce550697e79..904f4a303ade 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -472,7 +472,6 @@ class StableLmFlashAttention2(StableLmAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 6ce5e77706d3..d14964322795 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -595,7 +595,6 @@ class UniSpeechFlashAttention2(UniSpeechAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 52d82ea73942..49551b73577a 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -612,7 +612,6 @@ class UniSpeechSatFlashAttention2(UniSpeechSatAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index bf1bb7746ce8..ca743e1eaef3 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -659,7 +659,6 @@ class Wav2Vec2FlashAttention2(Wav2Vec2Attention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ce3df3e16707..fb01823a29c0 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -354,7 +354,6 @@ class WhisperFlashAttention2(WhisperAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index dee7f898fcf9..e49a921e3415 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -312,7 +312,6 @@ class ZambaFlashAttention2(ZambaAttention): flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From d189fe743aa4e6aa10622eefd68b2b44e2a57496 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 16 Dec 2024 13:50:07 +0100 Subject: [PATCH 23/97] remove gcv --- src/transformers/modeling_utils.py | 217 +++++++----------- src/transformers/models/aria/modeling_aria.py | 4 +- .../models/gemma/modular_gemma.py | 9 + .../models/gemma2/modular_gemma2.py | 21 +- src/transformers/models/glm/modeling_glm.py | 4 +- .../models/llama/modeling_llama.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- 7 files changed, 109 insertions(+), 154 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 16ce1dbb0ec0..c502742429e7 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2517,16 +2517,91 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - for layer in list(self.modules()): - if isinstance(layer, GradientCheckpointLayer): - layer.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + """ + Activates gradient checkpointing for the current model. - @property - def gradient_checkpointing(self, gradient_checkpointing_kwargs=None): - for layer in list(self.modules()): - if isinstance(layer, GradientCheckpointLayer): - return layer.gradient_checkpointing - return False + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() @property def is_gradient_checkpointing(self) -> bool: @@ -5575,127 +5650,3 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): "sdpa": sdpa_attention_forward, } ) - - -class GradientCheckpointLayer(torch.nn.Module): - def __init__(self, *args, **kwargs): - self.gradient_checkpointing = False - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - """ - Adjust the behavior of the inherited class by overriding `__call__`. - - Automatically handles gradient checkpointing based on flags in the provided arguments. - """ - # Extract necessary flags and arguments - gradient_checkpointing = kwargs.pop("gradient_checkpointing", False) | getattr( - self, "gradient_checkpointing", False - ) - training = self.training - - if gradient_checkpointing and training: - # Use gradient checkpointing - return self._apply_gradient_checkpointing(*args, **kwargs) - else: - # Default behavior: call the original `forward` method - return self.forward(*args, **kwargs) - - def _apply_gradient_checkpointing(self, *args, **kwargs): - """ - Apply gradient checkpointing using the appropriate function. - - By default, uses `torch.utils.checkpoint.checkpoint`. - """ - - # Assume `self.forward` is compatible with checkpointing - def wrapped_forward(): - return self.forward(*args, **kwargs) - - return self._gradient_checkpointing_func(wrapped_forward) - - def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - - We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of - the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - - Args: - gradient_checkpointing_kwargs (dict, *optional*): - Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. - """ - self.gradient_checkpointing = True - - if gradient_checkpointing_kwargs is None: - gradient_checkpointing_kwargs = {"use_reentrant": True} - - gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) - - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` method - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) - else: - self.apply(partial(self._set_gradient_checkpointing, value=True)) - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - - if getattr(self, "_hf_peft_config_loaded", False): - # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True - # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 - # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate - # the gradients to make sure the gradient flows. - self.enable_input_require_grads() - - def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): - is_gradient_checkpointing_set = False - - # Apply it on the top-level module in case the top-level modules supports it - # for example, LongT5Stack inherits from `PreTrainedModel`. - if hasattr(self, "gradient_checkpointing"): - self._gradient_checkpointing_func = gradient_checkpointing_func - self.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - for module in self.modules(): - if hasattr(module, "gradient_checkpointing"): - module._gradient_checkpointing_func = gradient_checkpointing_func - module.gradient_checkpointing = enable - is_gradient_checkpointing_set = True - - if not is_gradient_checkpointing_set: - raise ValueError( - f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" - " `gradient_checkpointing` to modules of the model that uses checkpointing." - ) - - def gradient_checkpointing_disable(self): - """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". - """ - if self.supports_gradient_checkpointing: - # For old GC format (transformers < 4.35.0) for models that live on the Hub - # we will fall back to the overwritten `_set_gradient_checkpointing` methid - _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters - if not _is_using_old_format: - self._set_gradient_checkpointing(enable=False) - else: - logger.warning( - "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." - "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." - ) - self.apply(partial(self._set_gradient_checkpointing, value=False)) - - if getattr(self, "_hf_peft_config_loaded", False): - self.disable_input_require_grads() diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index dc18a8324207..932fff4e528f 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -28,7 +28,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, @@ -557,7 +557,7 @@ def forward( return attn_output, attn_weights -class AriaTextDecoderLayer(GradientCheckpointLayer): +class AriaTextDecoderLayer(nn.Module): """ Aria Text Decoder Layer. diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 69972013ab28..f86f0c5373d8 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -361,6 +361,15 @@ class GemmaPreTrainedModel(LlamaPreTrainedModel): class GemmaModel(LlamaModel): + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet! + self.post_init() + def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 0a2e9e23fe01..b89596f03eb9 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,11 +29,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import ( - is_flash_attn_2_available, - is_torch_greater_or_equal, - logging, -) +from ...utils import logging from ..gemma.modeling_gemma import ( GemmaAttention, GemmaDecoderLayer, @@ -49,13 +45,6 @@ ) -if is_flash_attn_2_available(): - pass - -if is_torch_greater_or_equal("2.5"): - pass - - _CHECKPOINT_FOR_DOC = "google/gemma2-7b" logger = logging.get_logger(__name__) @@ -299,10 +288,16 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(GemmaDecoderLayer): +class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() + self.hidden_size = config.hidden_size + self.config = config self.is_sliding = not bool(layer_idx % 2) + self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) + self.mlp = Gemma2MLP(config) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0f14c7140efe..f23130db819b 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -36,7 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, @@ -262,7 +262,7 @@ def forward( return attn_output, attn_weights -class GlmDecoderLayer(GradientCheckpointLayer): +class GlmDecoderLayer(nn.ModuleLayer): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 9726d22fed37..aa318d254d47 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -36,7 +36,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( @@ -287,7 +287,7 @@ def forward( return attn_output, attn_weights -class LlamaDecoderLayer(GradientCheckpointLayer): +class LlamaDecoderLayer(nn.ModuleLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2a4dcbefc344..d0afcad98660 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -20,7 +20,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_mistral import MistralConfig @@ -261,7 +261,7 @@ def forward( return attn_output, attn_weights -class MistralDecoderLayer(GradientCheckpointLayer): +class MistralDecoderLayer(nn.ModuleLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size From 9c83d96955823dfbc43c3c4ad5c5cb670f418d76 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 16 Dec 2024 13:51:21 +0100 Subject: [PATCH 24/97] make fix-copies --- .../modular-transformers/modeling_dummy.py | 4 +- .../modeling_multimodal1.py | 4 +- .../modeling_my_new_model2.py | 4 +- .../modular-transformers/modeling_super.py | 4 +- .../models/cohere/modeling_cohere.py | 67 +++------- src/transformers/models/dbrx/modeling_dbrx.py | 58 +++++++-- .../modeling_decision_transformer.py | 106 +++++----------- .../models/falcon/modeling_falcon.py | 34 +----- .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 4 +- .../models/gpt_neox/modeling_gpt_neox.py | 34 +----- .../modeling_gpt_neox_japanese.py | 34 +----- .../models/granitemoe/modeling_granitemoe.py | 108 ++++++---------- .../models/idefics/modeling_idefics.py | 11 +- .../models/jamba/modeling_jamba.py | 6 +- .../models/jetmoe/modeling_jetmoe.py | 58 +++++++-- src/transformers/models/mimi/modeling_mimi.py | 115 ++++++++++-------- .../models/moshi/modeling_moshi.py | 62 +++++++--- .../models/nemotron/modeling_nemotron.py | 26 +--- src/transformers/models/olmo/modeling_olmo.py | 4 +- .../models/olmo2/modeling_olmo2.py | 4 +- .../models/olmoe/modeling_olmoe.py | 43 ++----- .../models/persimmon/modeling_persimmon.py | 34 +----- src/transformers/models/phi3/modeling_phi3.py | 58 +++++++-- .../models/phimoe/modeling_phimoe.py | 15 +-- .../models/pixtral/modeling_pixtral.py | 6 +- .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 94 +++++--------- .../models/qwen2_vl/modeling_qwen2_vl.py | 12 +- .../modeling_recurrent_gemma.py | 16 ++- .../models/stablelm/modeling_stablelm.py | 40 ++---- .../models/starcoder2/modeling_starcoder2.py | 4 +- .../models/zamba/modeling_zamba.py | 6 +- 33 files changed, 479 insertions(+), 604 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 1d336845de23..82973ed379f3 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_dummy import DummyConfig @@ -252,7 +252,7 @@ def forward( return attn_output, attn_weights -class DummyDecoderLayer(GradientCheckpointLayer): +class DummyDecoderLayer(nn.ModuleLayer): def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 78e37aa123d1..cdcb68bd1afb 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_multimodal1 import Multimodal1TextConfig @@ -252,7 +252,7 @@ def forward( return attn_output, attn_weights -class Multimodal1TextDecoderLayer(GradientCheckpointLayer): +class Multimodal1TextDecoderLayer(nn.ModuleLayer): def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index cc26a94defc6..d79adc771f1f 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_my_new_model2 import MyNewModel2Config @@ -187,7 +187,7 @@ def forward( return attn_output, attn_weights -class MyNewModel2DecoderLayer(GradientCheckpointLayer): +class MyNewModel2DecoderLayer(nn.ModuleLayer): def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index cfbaa5ac6e64..273ad3e68074 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from .configuration_super import SuperConfig @@ -249,7 +249,7 @@ def forward( return attn_output, attn_weights -class SuperDecoderLayer(GradientCheckpointLayer): +class SuperDecoderLayer(nn.ModuleLayer): def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 72f6f1c82e86..9e793308afb8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -827,31 +827,22 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -860,42 +851,25 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -905,18 +879,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 990b4d821e5b..4fc263214bf5 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -48,24 +48,55 @@ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx class DbrxRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[DbrxConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -73,6 +104,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index b8eb9f5a8b42..6829bf31aefc 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -161,46 +161,6 @@ def prune_heads(self, heads): self.num_heads = self.num_heads - len(heads) self.pruned_heads = self.pruned_heads.union(heads) - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device - ) - - # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) - - if not self.is_cross_attention: - # if only "normal" attention layer implements causal mask - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] - mask_value = torch.finfo(attn_weights.dtype).min - # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. - # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) - attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) - - if attention_mask is not None: - # Apply the attention mask - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - # Mask heads if we want to - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) bsz, num_heads, q_seq_len, dk = query.size() @@ -253,32 +213,17 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea return attn_output, attn_weights - def _split_heads(self, tensor, num_heads, attn_head_size): - """ - Splits hidden_size dim into attn_head_size and num_heads - """ - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) - - def _merge_heads(self, tensor, num_heads, attn_head_size): - """ - Merges attn_head_size dim and num_attn_heads dim into hidden_size - """ - tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: if not hasattr(self, "q_attn"): @@ -293,34 +238,39 @@ def forward( else: query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - if layer_past is not None: - past_key, past_value = layer_past - key = torch.cat((past_key, key), dim=-2) - value = torch.cat((past_value, value), dim=-2) + query_states = query.view(hidden_shape).transpose(1, 2) + key_states = key.view(hidden_shape).transpose(1, 2) + value_states = value.view(hidden_shape).transpose(1, 2) - if use_cache is True: - present = (key, value) - else: - present = None + cos, sin = position_embeddings + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) else: - attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) - attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2 diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b2ca3ffed67b..e2f609bc348a 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -113,40 +113,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class FalconRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[FalconConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ced8b374ba04..389d31badfbb 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -36,7 +36,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, @@ -217,7 +217,7 @@ def forward( return attn_output, attn_weights -class GemmaDecoderLayer(GradientCheckpointLayer): +class GemmaDecoderLayer(nn.ModuleLayer): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 3f66f2254cf0..e708c686f1c4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -35,7 +35,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, @@ -301,7 +301,7 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(GradientCheckpointLayer): +class Gemma2DecoderLayer(nn.ModuleLayer): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 70ff07ed7f6d..c0559d0e4cbc 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -490,40 +490,18 @@ def __init__(self, config, layer_idx=None): class GPTNeoXRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[GPTNeoXConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index c9e1b2d72135..498573f065d1 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -227,40 +227,18 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class GPTNeoXJapaneseRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[GPTNeoXJapaneseConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 6d706c463270..7da21d7dba38 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -158,10 +158,14 @@ def extra_repr(self): # Copied from transformers.models.granite.modeling_granite.GraniteRotaryEmbedding with Granite->GraniteMoe class GraniteMoeRotaryEmbedding(nn.Module): - def __init__(self, config: GraniteMoeConfig): + def __init__( + self, + device=None, + config: Optional[GraniteMoeConfig] = None, + ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} + # BC: "rope_type" was originally "type" if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: @@ -172,7 +176,7 @@ def __init__(self, config: GraniteMoeConfig): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -421,55 +425,37 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.is_causal = True - + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.attention_multiplier - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -479,35 +465,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 8bd24728b038..ea3818990c7c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -445,7 +445,7 @@ def rotate_half(x): # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -453,9 +453,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -466,8 +465,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 07643d243e2a..ae7470d789b2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -834,6 +834,7 @@ def forward( class JambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -841,8 +842,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index cf30e9c02405..17b9c28135de 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -385,24 +385,55 @@ def extra_repr(self): # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe class JetMoeRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[JetMoeConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -410,6 +441,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 3ba64a554831..333dde03e0ce 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -364,24 +364,55 @@ def forward(self, x: torch.Tensor): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi class MimiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[MimiConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -389,6 +420,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -503,24 +539,19 @@ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -528,35 +559,21 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + **kwargs, + ) - attn_output = attn_output.view(bsz, q_len, -1) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index f3d2f2d9d93d..30715f9ec267 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -307,24 +307,55 @@ def forward(self, x, layer_idx=None): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi class MoshiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[MoshiConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -332,6 +363,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -466,12 +502,10 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a7cb80a19433..39569846b78b 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -517,7 +517,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronDecoderLayer(nn.Module): +class NemotronDecoderLayer(nn.ModuleLayer): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): super().__init__() @@ -539,30 +539,8 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f634c52f2fc7..f1a411009c20 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -23,7 +23,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, @@ -228,7 +228,7 @@ def forward( return attn_output, attn_weights -class OlmoDecoderLayer(GradientCheckpointLayer): +class OlmoDecoderLayer(nn.ModuleLayer): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 31cb6b487144..39f6df4b6f7e 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -16,7 +16,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, @@ -204,7 +204,7 @@ def forward(self, x): return down_proj -class Olmo2DecoderLayer(GradientCheckpointLayer): +class Olmo2DecoderLayer(nn.ModuleLayer): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a4b20ee1207e..e6facc8093c3 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -160,40 +160,18 @@ def extra_repr(self): class OlmoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[OlmoeConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -287,13 +265,14 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 884ee4d86aaf..765711711ca8 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -59,40 +59,18 @@ class PersimmonRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[PersimmonConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index d9d7db63e27f..0ba55f254f6a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -76,24 +76,55 @@ def extra_repr(self): # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 class Phi3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[Phi3Config] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -101,6 +132,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 82763ccea62e..496322096444 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -187,7 +187,7 @@ def rotate_half(x): # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -195,9 +195,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -208,8 +207,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -912,10 +911,12 @@ class PhimoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PhimoeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index b65fbd634ba7..03886d4a5284 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -216,6 +216,7 @@ def forward( class PixtralMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -223,8 +224,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 24c1c6ce3e65..78db5f2117e1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -22,7 +22,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...utils import ( LossKwargs, add_code_sample_docstrings, @@ -212,7 +212,7 @@ def forward(self, x): return down_proj -class Qwen2DecoderLayer(GradientCheckpointLayer): +class Qwen2DecoderLayer(nn.ModuleLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0de6d3a1a727..d3e431ce723a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -169,40 +169,18 @@ def extra_repr(self): class Qwen2MoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[Qwen2MoeConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -320,43 +298,28 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe class Qwen2MoeAttention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ + """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): + def __init__(self, config: Qwen2MoeConfig, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) # Ignore copy def forward( @@ -1577,22 +1540,21 @@ def forward( ) # Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): - base_model_prefix = "model" + base_model_prefix = "transformer" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe def __init__(self, config): super().__init__(config) - self.model = Qwen2MoeModel(config) + self.transformer = Qwen2MoeModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.transformer.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value + self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( @@ -1621,7 +1583,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dce0702b0819..26c489dc2a76 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -460,15 +460,17 @@ def extra_repr(self): class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 2b3cf7eb0cb8..b9d720eb39dc 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -78,13 +78,14 @@ def __init__(self, dim, base=10000, device=None): @torch.no_grad() # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -92,6 +93,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 904f4a303ade..2abde43770f8 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -65,40 +65,18 @@ class StableLmRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[StableLmConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -189,6 +167,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class StableLmMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -196,8 +175,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class StableLmLayerNormPerHead(nn.Module): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ee515d77611f..e5d831f7e1f5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -41,7 +41,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, @@ -272,7 +272,7 @@ def forward( return attn_output, attn_weights -class Starcoder2DecoderLayer(GradientCheckpointLayer): +class Starcoder2DecoderLayer(nn.ModuleLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index e49a921e3415..3b7348eadd47 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -773,6 +773,7 @@ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache class ZambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -780,8 +781,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class ZambaAttentionDecoderLayer(nn.Module): From 138368ecce89e7054048929de0ac9c03937b1792 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 16 Dec 2024 14:00:22 +0100 Subject: [PATCH 25/97] oups --- examples/modular-transformers/modeling_my_new_model2.py | 2 +- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index d79adc771f1f..9da2f5a51511 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_my_new_model2 import MyNewModel2Config diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 273ad3e68074..6e6d11be2535 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -15,7 +15,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward from .configuration_super import SuperConfig diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e708c686f1c4..cb001538aabe 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -35,7 +35,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f23130db819b..7d732d8a38c7 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -36,7 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( add_code_sample_docstrings, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index aa318d254d47..734d0b2d9cad 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -36,7 +36,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f1a411009c20..0fd9564ad7f5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -23,7 +23,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 39f6df4b6f7e..33eff916ba05 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -16,7 +16,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 78db5f2117e1..ed275327036a 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -22,7 +22,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ( LossKwargs, add_code_sample_docstrings, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index e5d831f7e1f5..ed2737c9ab7f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -41,7 +41,7 @@ TokenClassifierOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, nn.ModuleLayer, PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, From 7225a4f33cca0b68b937b5f8b5456b050821cdf1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 16 Dec 2024 14:01:50 +0100 Subject: [PATCH 26/97] oups2.0 --- .../modular-transformers/modeling_dummy.py | 2 +- .../modeling_multimodal1.py | 2 +- .../modeling_my_new_model2.py | 2 +- .../modular-transformers/modeling_super.py | 2 +- .../models/gemma/modeling_gemma.py | 69 +------------------ .../models/gemma2/modeling_gemma2.py | 11 ++- src/transformers/models/glm/modeling_glm.py | 61 ++++++++++++---- .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- 14 files changed, 66 insertions(+), 97 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 82973ed379f3..4460caf85f67 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -252,7 +252,7 @@ def forward( return attn_output, attn_weights -class DummyDecoderLayer(nn.ModuleLayer): +class DummyDecoderLayer(nn.Module): def __init__(self, config: DummyConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index cdcb68bd1afb..abc11092d4c2 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -252,7 +252,7 @@ def forward( return attn_output, attn_weights -class Multimodal1TextDecoderLayer(nn.ModuleLayer): +class Multimodal1TextDecoderLayer(nn.Module): def __init__(self, config: Multimodal1TextConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 9da2f5a51511..c83cc76f4581 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -187,7 +187,7 @@ def forward( return attn_output, attn_weights -class MyNewModel2DecoderLayer(nn.ModuleLayer): +class MyNewModel2DecoderLayer(nn.Module): def __init__(self, config: MyNewModel2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 6e6d11be2535..80129ce40c3a 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -249,7 +249,7 @@ def forward( return attn_output, attn_weights -class SuperDecoderLayer(nn.ModuleLayer): +class SuperDecoderLayer(nn.Module): def __init__(self, config: SuperConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 389d31badfbb..1eb08264e714 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -35,7 +35,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -217,7 +216,7 @@ def forward( return attn_output, attn_weights -class GemmaDecoderLayer(nn.ModuleLayer): +class GemmaDecoderLayer(nn.Module): def __init__(self, config: GemmaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -320,71 +319,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -class GemmaRotaryEmbedding(nn.Module): - def __init__( - self, - device=None, - config: Optional[GemmaConfig] = None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -482,7 +416,6 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = GemmaRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index cb001538aabe..bc8b0f23e37a 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -301,17 +301,16 @@ def forward( return attn_output, attn_weights -class Gemma2DecoderLayer(nn.ModuleLayer): +class Gemma2DecoderLayer(nn.Module): def __init__(self, config: Gemma2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - + self.config = config + self.is_sliding = not bool(layer_idx % 2) self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx) - self.mlp = Gemma2MLP(config) - self.input_layernorm = Gemma2RMSNorm(config.hidden_size) - self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size) - self.is_sliding = not bool(layer_idx % 2) + self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 7d732d8a38c7..a54272cc710b 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -36,6 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -75,24 +76,55 @@ def extra_repr(self): class GlmRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[GlmConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -100,6 +132,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -262,7 +299,7 @@ def forward( return attn_output, attn_weights -class GlmDecoderLayer(nn.ModuleLayer): +class GlmDecoderLayer(nn.Module): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 734d0b2d9cad..134cbb3e110d 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -287,7 +287,7 @@ def forward( return attn_output, attn_weights -class LlamaDecoderLayer(nn.ModuleLayer): +class LlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d0afcad98660..c10f79f6bcc1 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -261,7 +261,7 @@ def forward( return attn_output, attn_weights -class MistralDecoderLayer(nn.ModuleLayer): +class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 39569846b78b..40f588f8f1a2 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -517,7 +517,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronDecoderLayer(nn.ModuleLayer): +class NemotronDecoderLayer(nn.Module): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): super().__init__() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 0fd9564ad7f5..ddd0faabb4c0 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -228,7 +228,7 @@ def forward( return attn_output, attn_weights -class OlmoDecoderLayer(nn.ModuleLayer): +class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 33eff916ba05..606d4041a6ca 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -204,7 +204,7 @@ def forward(self, x): return down_proj -class Olmo2DecoderLayer(nn.ModuleLayer): +class Olmo2DecoderLayer(nn.Module): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index ed275327036a..d5968af93fbe 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -212,7 +212,7 @@ def forward(self, x): return down_proj -class Qwen2DecoderLayer(nn.ModuleLayer): +class Qwen2DecoderLayer(nn.Module): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ed2737c9ab7f..c5dd8cc815e2 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -272,7 +272,7 @@ def forward( return attn_output, attn_weights -class Starcoder2DecoderLayer(nn.ModuleLayer): +class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size From a3b9195f37d4d54c2d48863fcac5bf568afc235c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:02:19 +0100 Subject: [PATCH 27/97] fix some modulars + all copied from --- .../models/cohere/modeling_cohere.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 58 ++++++-- .../modeling_decision_transformer.py | 3 +- .../models/falcon/modeling_falcon.py | 34 +---- src/transformers/models/glm/modeling_glm.py | 59 ++++++-- .../models/gpt_neox/modeling_gpt_neox.py | 34 +---- .../modeling_gpt_neox_japanese.py | 34 +---- .../models/granitemoe/modeling_granitemoe.py | 13 +- .../models/idefics/modeling_idefics.py | 3 +- .../models/jamba/modeling_jamba.py | 6 +- .../models/jetmoe/modeling_jetmoe.py | 58 ++++++-- src/transformers/models/mimi/modeling_mimi.py | 59 ++++++-- .../models/mistral/modeling_mistral.py | 135 +++++++++++++++++- .../models/mistral/modular_mistral.py | 5 + .../models/moshi/modeling_moshi.py | 59 ++++++-- .../models/nemotron/modeling_nemotron.py | 3 +- src/transformers/models/olmo/modeling_olmo.py | 6 +- src/transformers/models/olmo/modular_olmo.py | 6 +- .../models/olmo2/modeling_olmo2.py | 6 +- .../models/olmoe/modeling_olmoe.py | 37 ++--- .../models/persimmon/modeling_persimmon.py | 34 +---- src/transformers/models/phi3/modeling_phi3.py | 58 ++++++-- .../models/phimoe/modeling_phimoe.py | 7 +- .../models/pixtral/modeling_pixtral.py | 6 +- .../models/qwen2/modeling_qwen2.py | 32 ++--- .../models/qwen2/modular_qwen2.py | 10 ++ .../models/qwen2_moe/modeling_qwen2_moe.py | 48 ++----- .../models/qwen2_vl/modeling_qwen2_vl.py | 6 +- .../modeling_recurrent_gemma.py | 16 ++- .../models/stablelm/modeling_stablelm.py | 40 ++---- .../models/zamba/modeling_zamba.py | 6 +- 31 files changed, 565 insertions(+), 319 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 72f6f1c82e86..2686879fe9b0 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -761,7 +761,8 @@ def _init_weights(self, module): "The bare Cohere Model outputting raw hidden-states without any specific head on top.", COHERE_START_DOCSTRING, ) -# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE +# copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->Cohere, LLAMA->COHERE +# TODO cyril: modular class CohereModel(CoherePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`CohereDecoderLayer`] diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 990b4d821e5b..4fc263214bf5 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -48,24 +48,55 @@ # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx class DbrxRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[DbrxConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -73,6 +104,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index b8eb9f5a8b42..7e8c73d025dd 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -100,7 +100,8 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model -# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 +# copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 +# no longer copied after attention refactors class DecisionTransformerGPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b2ca3ffed67b..e2f609bc348a 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -113,40 +113,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class FalconRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[FalconConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`FalconRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0f14c7140efe..200c6b5b937a 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -36,6 +36,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -75,24 +76,55 @@ def extra_repr(self): class GlmRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[GlmConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -100,6 +132,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 70ff07ed7f6d..c0559d0e4cbc 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -490,40 +490,18 @@ def __init__(self, config, layer_idx=None): class GPTNeoXRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[GPTNeoXConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index c9e1b2d72135..498573f065d1 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -227,40 +227,18 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class GPTNeoXJapaneseRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[GPTNeoXJapaneseConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`GPTNeoXJapaneseRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 6d706c463270..ce36905ba23c 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -158,10 +158,14 @@ def extra_repr(self): # Copied from transformers.models.granite.modeling_granite.GraniteRotaryEmbedding with Granite->GraniteMoe class GraniteMoeRotaryEmbedding(nn.Module): - def __init__(self, config: GraniteMoeConfig): + def __init__( + self, + device=None, + config: Optional[GraniteMoeConfig] = None, + ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} + # BC: "rope_type" was originally "type" if config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: @@ -172,7 +176,7 @@ def __init__(self, config: GraniteMoeConfig): self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=None, **self.rope_kwargs) + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq @@ -413,7 +417,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe +# copied from transformers.models.granite.modeling_granite.GraniteAttention with Granite->GraniteMoe +# no longer copied after attention refactors class GraniteMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 8bd24728b038..3489f931f890 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -444,7 +444,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# no longer copied after attention refactors def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 07643d243e2a..ae7470d789b2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -834,6 +834,7 @@ def forward( class JambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -841,8 +842,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index cf30e9c02405..17b9c28135de 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -385,24 +385,55 @@ def extra_repr(self): # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->JetMoe class JetMoeRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[JetMoeConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -410,6 +441,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 3ba64a554831..1c5d1ea30a16 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -364,24 +364,55 @@ def forward(self, x: torch.Tensor): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mimi class MimiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[MimiConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -389,6 +420,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -457,7 +493,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# copied from transformers.models.gemma.modeling_gemma.GemmaAttention with Gemma->Mimi +# no longer copied after attention refactors class MimiAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2a4dcbefc344..90dd24c653e9 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -11,10 +11,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPast, + CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, @@ -22,7 +24,14 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, GradientCheckpointLayer, PreTrainedModel from ...processing_utils import Unpack -from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ...utils import ( + LossKwargs, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) from .configuration_mistral import MistralConfig @@ -718,6 +727,130 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @add_start_docstrings( """ The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index fd0c86218050..6e6a18f9ef65 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -5,6 +5,7 @@ from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ..llama.modeling_llama import ( + LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, @@ -178,6 +179,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class MistralForCausalLM(LlamaForCausalLM): + pass + + class MistralForTokenClassification(LlamaForTokenClassification): pass diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index f3d2f2d9d93d..7b1e5bef0309 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -307,24 +307,55 @@ def forward(self, x, layer_idx=None): # Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi class MoshiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[MoshiConfig] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward - # TODO(joao): add me back asap :) def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -332,6 +363,11 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -462,7 +498,8 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle base=self.rope_theta, ) - # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward + # no longer copied after attention refactors def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index a7cb80a19433..a0a10bdc6f35 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -516,7 +516,8 @@ def forward( } -# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +# no longer copied after attention refactors class NemotronDecoderLayer(nn.Module): # Ignore copy def __init__(self, config: NemotronConfig, layer_idx: int): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f634c52f2fc7..42317c54d9eb 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -82,9 +82,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index e43976969ace..232535a8c14f 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -41,7 +41,11 @@ class OlmoRMSNorm(LlamaRMSNorm): class OlmoMLP(LlamaMLP): - pass + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) class OlmoAttention(LlamaAttention): diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 31cb6b487144..5273a0e63b84 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -194,9 +194,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a4b20ee1207e..7ceaa3622b6e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -160,40 +160,18 @@ def extra_repr(self): class OlmoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[OlmoeConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`OlmoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -293,7 +271,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 884ee4d86aaf..765711711ca8 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -59,40 +59,18 @@ class PersimmonRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[PersimmonConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index d9d7db63e27f..0ba55f254f6a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -76,24 +76,55 @@ def extra_repr(self): # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 class Phi3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, + device=None, + config: Optional[Phi3Config] = None, + ): super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) - self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len @torch.no_grad() - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -101,6 +132,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 82763ccea62e..9ece8ecda027 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -186,7 +186,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb +# no longer copied after attention refactors def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. @@ -912,10 +913,12 @@ class PhimoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PhimoeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" + _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index b65fbd634ba7..03886d4a5284 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -216,6 +216,7 @@ def forward( class PixtralMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -223,8 +224,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Pixtral diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 24c1c6ce3e65..493ac34e5386 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -40,6 +40,22 @@ _CONFIG_FOR_DOC = "Qwen2Config" +class Qwen2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -196,22 +212,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - class Qwen2DecoderLayer(GradientCheckpointLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index edadaa6e1188..f96bfbe1a98f 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -2,6 +2,7 @@ import torch import torch.utils.checkpoint +from torch import nn from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -14,6 +15,7 @@ LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, + LlamaMLP, LlamaModel, apply_rotary_pos_emb, eager_attention_forward, @@ -24,6 +26,14 @@ logger = logging.get_logger(__name__) +class Qwen2MLP(LlamaMLP): + def __init__(self, config): + super().__init__(config) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + class Qwen2Attention(LlamaAttention): def forward( self, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0de6d3a1a727..a3c0d36e2a5a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -169,40 +169,18 @@ def extra_repr(self): class Qwen2MoeRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[Qwen2MoeConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`Qwen2MoeRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -318,7 +296,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe +# no longer copied after attention refactors class Qwen2MoeAttention(nn.Module): """ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer @@ -1577,22 +1556,21 @@ def forward( ) # Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): - base_model_prefix = "model" + base_model_prefix = "transformer" - # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Qwen2Moe def __init__(self, config): super().__init__(config) - self.model = Qwen2MoeModel(config) + self.transformer = Qwen2MoeModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.model.embed_tokens + return self.transformer.embed_tokens def set_input_embeddings(self, value): - self.model.embed_tokens = value + self.transformer.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( @@ -1621,7 +1599,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.model( + outputs = self.transformer( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dce0702b0819..10c9b1638548 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -460,6 +460,7 @@ def extra_repr(self): class Qwen2MLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -467,8 +468,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Copied from transformers.models.llama.modeling_llama.repeat_kv diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 2b3cf7eb0cb8..b9d720eb39dc 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -78,13 +78,14 @@ def __init__(self, dim, base=10000, device=None): @torch.no_grad() # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma - def forward(self, x, position_ids, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - self.inv_freq.to(x.device) + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -92,6 +93,11 @@ def forward(self, x, position_ids, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 904f4a303ade..2abde43770f8 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -65,40 +65,18 @@ class StableLmRotaryEmbedding(nn.Module): def __init__( self, - dim=None, - max_position_embeddings=2048, - base=10000, device=None, - scaling_factor=1.0, - rope_type="default", config: Optional[StableLmConfig] = None, ): super().__init__() - # TODO (joao): remove the `if` below, only used for BC self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`StableLmRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.46" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings self.config = config self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] @@ -189,6 +167,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class StableLmMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -196,8 +175,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class StableLmLayerNormPerHead(nn.Module): diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index e49a921e3415..3b7348eadd47 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -773,6 +773,7 @@ def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache class ZambaMLP(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) @@ -780,8 +781,9 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] - def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj class ZambaAttentionDecoderLayer(nn.Module): From 8d93708ea6e22c1728cd6734a690dbf0c2c7d6ab Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 16 Dec 2024 14:09:59 +0100 Subject: [PATCH 28/97] should be good now --- .../modular-transformers/modeling_dummy.py | 35 +++-- .../modeling_multimodal1.py | 35 +++-- .../modeling_my_new_model2.py | 5 + src/transformers/models/aria/modeling_aria.py | 35 +++-- .../models/cohere/modeling_cohere.py | 35 +++-- .../models/gemma/modeling_gemma.py | 72 ++++++++++ .../models/gemma/modular_gemma.py | 6 +- .../models/gemma2/modeling_gemma2.py | 130 +++++++++--------- .../models/gemma2/modular_gemma2.py | 4 - src/transformers/models/glm/modeling_glm.py | 35 +++-- .../models/granite/modeling_granite.py | 35 +++-- .../models/llama/modeling_llama.py | 35 +++-- .../models/mistral/modeling_mistral.py | 35 +++-- src/transformers/models/olmo/modeling_olmo.py | 35 +++-- .../models/olmo2/modeling_olmo2.py | 35 +++-- src/transformers/models/phi/modeling_phi.py | 35 +++-- .../models/qwen2/modeling_qwen2.py | 35 +++-- 17 files changed, 435 insertions(+), 202 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 4460caf85f67..ea05b9f5fa1a 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -526,17 +526,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index abc11092d4c2..4b8ca8a6689b 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -526,17 +526,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index c83cc76f4581..404b83288aa9 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -526,6 +526,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # MyNewModel2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -551,6 +554,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -561,6 +565,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 932fff4e528f..31a6bcefb6f2 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -942,17 +942,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9e793308afb8..bf37d1ff9d5d 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -856,17 +856,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 1eb08264e714..2b92e21ace96 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -35,6 +35,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -319,6 +320,71 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +class GemmaRotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[GemmaConfig] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -416,6 +482,7 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = GemmaRotaryEmbedding(config=config) # Initialize weights and apply final processing self.post_init() @@ -489,6 +556,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -514,6 +584,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -524,6 +595,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index f86f0c5373d8..a73ea00f46f9 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -367,7 +367,6 @@ def __init__(self, config: GemmaConfig): [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - del self.rotary_emb # Gemma does not implement rotary emb at the modeling level yet! self.post_init() def forward( @@ -432,6 +431,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -457,6 +459,7 @@ def forward( output_attentions, use_cache, cache_position, + position_embeddings ) else: layer_outputs = decoder_layer( @@ -467,6 +470,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index bc8b0f23e37a..0688b3756c46 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -89,71 +89,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2RotaryEmbedding(nn.Module): - def __init__( - self, - device=None, - config: Optional[Gemma2Config] = None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -419,6 +354,71 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() +class Gemma2RotaryEmbedding(nn.Module): + def __init__( + self, + device=None, + config: Optional[Gemma2Config] = None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + GEMMA2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index b89596f03eb9..9d6f4bfef862 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -202,10 +202,6 @@ def forward(self, x): return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -class Gemma2RotaryEmbedding(GemmaRotaryEmbedding): - pass - - def eager_attention_forward( module: nn.Module, query: torch.Tensor, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index a54272cc710b..15662c360fc5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -578,17 +578,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 9ea2c2c3ec16..f7aebb50d967 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -573,17 +573,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 134cbb3e110d..35d7ce410a86 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -561,17 +561,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c10f79f6bcc1..ae3caa046295 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -535,17 +535,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index ddd0faabb4c0..740671dd2ea5 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -566,17 +566,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 606d4041a6ca..2325af48d9a2 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -536,17 +536,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 64364d54b3a4..b3167433b3e7 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -565,17 +565,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index d5968af93fbe..62dadabb5e20 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -556,17 +556,30 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **flash_attn_kwargs, - ) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) hidden_states = layer_outputs[0] From 54d9b95485cf1e5d612bb945635145ccaf2eeb12 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:30:26 +0100 Subject: [PATCH 29/97] revert unwanted changes --- .../models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modular_gemma2.py | 2 - .../models/granitemoe/modeling_granitemoe.py | 98 ++++++++++++------- src/transformers/models/mimi/modeling_mimi.py | 59 +++++++---- .../models/moshi/modeling_moshi.py | 6 +- .../models/nemotron/modeling_nemotron.py | 24 ++++- .../models/olmoe/modeling_olmoe.py | 6 +- .../models/qwen2/modeling_qwen2.py | 16 --- .../models/qwen2_moe/modeling_qwen2_moe.py | 49 ++++++---- .../models/qwen2_vl/modeling_qwen2_vl.py | 6 +- 10 files changed, 170 insertions(+), 98 deletions(-) diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index a73ea00f46f9..b2a0fcdecf6c 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -459,7 +459,7 @@ def forward( output_attentions, use_cache, cache_position, - position_embeddings + position_embeddings, ) else: layer_outputs = decoder_layer( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 9d6f4bfef862..b166ea72957e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -32,14 +32,12 @@ from ...utils import logging from ..gemma.modeling_gemma import ( GemmaAttention, - GemmaDecoderLayer, GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, GemmaModel, GemmaPreTrainedModel, GemmaRMSNorm, - GemmaRotaryEmbedding, apply_rotary_pos_emb, repeat_kv, ) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 5fc5294b83f7..ce36905ba23c 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -426,37 +426,55 @@ def __init__(self, config: GraniteMoeConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.scaling = config.attention_multiplier - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -466,21 +484,35 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - **kwargs, - ) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value # NO LONGER EXIST Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 5fa72a1bd31b..1c5d1ea30a16 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -540,19 +540,24 @@ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - cos, sin = position_embeddings + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -560,21 +565,35 @@ def forward( cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - **kwargs, - ) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value # NO LONGER EXIST Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index a5501aab2bcf..7b1e5bef0309 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -503,10 +503,12 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle def forward( self, hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 64fc240980b1..a0a10bdc6f35 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -540,8 +540,30 @@ def forward( use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ residual = hidden_states hidden_states = self.input_layernorm(hidden_states) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index e6facc8093c3..7ceaa3622b6e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -265,9 +265,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f5ddd5c81558..b98bd67252e8 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -212,22 +212,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - class Qwen2DecoderLayer(nn.Module): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index b86cf0100c4c..a3c0d36e2a5a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -299,28 +299,43 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: # copied from transformers.models.qwen2.modeling_qwen2.Qwen2Attention with Qwen2->Qwen2Moe # no longer copied after attention refactors class Qwen2MoeAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ - def __init__(self, config: Qwen2MoeConfig, layer_idx: int): + def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) # Ignore copy def forward( diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 26c489dc2a76..10c9b1638548 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -463,9 +463,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): From 944e26e9a605a0d714d28785ce46ed48cbab8000 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:47:46 +0100 Subject: [PATCH 30/97] Update modeling_decision_transformer.py --- .../modeling_decision_transformer.py | 48 +++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index dbcad37c02cb..445bc5d09503 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -17,7 +17,7 @@ import math import os from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +25,7 @@ from ...activations import ACT2FN from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer from ...utils import ( ModelOutput, @@ -100,8 +100,48 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model -# copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 -# no longer copied after attention refactors +# Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward +def eager_attention_forward(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + +# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 class DecisionTransformerGPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): super().__init__() From 911833f86f45350f89d0abee7316a19b89becd1d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:51:16 +0100 Subject: [PATCH 31/97] finish cleanup --- src/transformers/models/dbrx/modeling_dbrx.py | 1 + .../modeling_decision_transformer.py | 1 + src/transformers/models/jetmoe/modeling_jetmoe.py | 1 + src/transformers/models/mimi/modeling_mimi.py | 1 + src/transformers/models/moshi/modeling_moshi.py | 1 + src/transformers/models/olmo/modular_olmo.py | 15 --------------- src/transformers/models/phi3/modeling_phi3.py | 1 + 7 files changed, 6 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 4fc263214bf5..9294fff11db5 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -26,6 +26,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 445bc5d09503..1bc96f19bb39 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -141,6 +141,7 @@ def eager_attention_forward(self, query, key, value, attention_mask=None, head_m return attn_output, attn_weights + # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 class DecisionTransformerGPT2Attention(nn.Module): def __init__(self, config, is_cross_attention=False, layer_idx=None): diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 17b9c28135de..499b15a382a5 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -32,6 +32,7 @@ MoeModelOutputWithPast, SequenceClassifierOutputWithPast, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 1c5d1ea30a16..68ee9770d5a7 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 7b1e5bef0309..5fc97e3e2f73 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -36,6 +36,7 @@ ModelOutput, Seq2SeqLMOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import ( diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 232535a8c14f..72fc011e44f8 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -11,9 +11,6 @@ LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, LlamaMLP, LlamaRMSNorm, apply_rotary_pos_emb, @@ -111,15 +108,3 @@ def __init__(self, config: OlmoConfig, layer_idx: int): class OlmoForCausalLM(LlamaForCausalLM): pass - - -class OlmoForTokenClassification(LlamaForTokenClassification): - pass - - -class OlmoForSequenceClassification(LlamaForSequenceClassification): - pass - - -class OlmoForQuestionAnswering(LlamaForQuestionAnswering): - pass diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 0ba55f254f6a..35f26971003b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -35,6 +35,7 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, From ea2691097c42a411ccf823a26f41fe22d2e4a009 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:52:28 +0100 Subject: [PATCH 32/97] Update modeling_olmo.py --- src/transformers/models/olmo/modeling_olmo.py | 289 +----------------- 1 file changed, 1 insertion(+), 288 deletions(-) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index d42e69ce71a1..1a45a34734f1 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -15,19 +15,12 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - QuestionAnsweringModelOutput, - SequenceClassifierOutputWithPast, - TokenClassifierOutput, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( LossKwargs, - add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, @@ -37,8 +30,6 @@ logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "meta-olmo/Olmo-2-7b-hf" _CONFIG_FOR_DOC = "OlmoConfig" @@ -854,281 +845,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - -@add_start_docstrings( - """ - The Olmo Model transformer with a token classification head on top (a linear layer on top of the hidden-states - output) e.g. for Named-Entity-Recognition (NER) tasks. - """, - OLMO_START_DOCSTRING, -) -class OlmoForTokenClassification(OlmoPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = OlmoModel(config) - if getattr(config, "classifier_dropout", None) is not None: - classifier_dropout = config.classifier_dropout - elif getattr(config, "hidden_dropout", None) is not None: - classifier_dropout = config.hidden_dropout - else: - classifier_dropout = 0.1 - self.dropout = nn.Dropout(classifier_dropout) - self.score = nn.Linear(config.hidden_size, config.num_labels) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=TokenClassifierOutput, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, TokenClassifierOutput]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = outputs[0] - sequence_output = self.dropout(sequence_output) - logits = self.score(sequence_output) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.config) - - if not return_dict: - output = (logits,) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return TokenClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@add_start_docstrings( - """ - The Olmo Model transformer with a sequence classification head on top (linear layer). - - [`OlmoForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - OLMO_START_DOCSTRING, -) -class OlmoForSequenceClassification(OlmoPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = OlmoModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - -@add_start_docstrings( - """ -The Olmo Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - OLMO_START_DOCSTRING, -) -class OlmoForQuestionAnswering(OlmoPreTrainedModel): - base_model_prefix = "transformer" - - def __init__(self, config): - super().__init__(config) - self.transformer = OlmoModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.embed_tokens - - def set_input_embeddings(self, value): - self.transformer.embed_tokens = value - - @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) From bc421af3da8a5d6dd0e02dcbe69f3fc25512234c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 14:59:21 +0100 Subject: [PATCH 33/97] consistency --- src/transformers/models/phi/modeling_phi.py | 87 ------------------- src/transformers/models/phi/modular_phi.py | 5 -- .../models/qwen2/modeling_qwen2.py | 3 +- .../models/qwen2/modular_qwen2.py | 3 +- 4 files changed, 4 insertions(+), 94 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index b3167433b3e7..ab00d8c4c671 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -19,7 +19,6 @@ from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, - QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) @@ -1045,89 +1044,3 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) - - -@add_start_docstrings( - """ -The Phi Model transformer with a span classification head on top for extractive question-answering tasks like -SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). - """, - PHI_START_DOCSTRING, -) -class PhiForQuestionAnswering(PhiPreTrainedModel): - base_model_prefix = "transformer" - - def __init__(self, config): - super().__init__(config) - self.transformer = PhiModel(config) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.transformer.embed_tokens - - def set_input_embeddings(self, value): - self.transformer.embed_tokens = value - - @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - start_positions: Optional[torch.LongTensor] = None, - end_positions: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[Tuple, QuestionAnsweringModelOutput]: - r""" - start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the start of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for position (index) of the end of the labelled span for computing the token classification loss. - Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence - are not taken into account for computing the loss. - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - outputs = self.transformer( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - sequence_output = outputs[0] - - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1).contiguous() - end_logits = end_logits.squeeze(-1).contiguous() - - loss = None - if start_positions is not None and end_positions is not None: - loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) - - if not return_dict: - output = (start_logits, end_logits) + outputs[2:] - return ((loss,) + output) if loss is not None else output - - return QuestionAnsweringModelOutput( - loss=loss, - start_logits=start_logits, - end_logits=end_logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index c4036bee9452..051b9af61f9e 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -9,7 +9,6 @@ from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, - LlamaForQuestionAnswering, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaMLP, @@ -155,7 +154,3 @@ class PhiForSequenceClassification(LlamaForSequenceClassification): class PhiForTokenClassification(LlamaForTokenClassification): pass - - -class PhiForQuestionAnswering(LlamaForQuestionAnswering): - pass diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b98bd67252e8..4bc7a32df5f7 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_qwen2.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from typing import Callable, List, Optional, Tuple, Union, Unpack +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -23,6 +23,7 @@ ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( LossKwargs, add_code_sample_docstrings, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index f96bfbe1a98f..01609b0ca3a7 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -1,9 +1,10 @@ -from typing import Callable, Optional, Tuple, Unpack +from typing import Callable, Optional, Tuple import torch import torch.utils.checkpoint from torch import nn +from ...processing_utils import Unpack from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS From 8664ddcdc566e925726bbc3b72a00a8b13ca006e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 15:13:05 +0100 Subject: [PATCH 34/97] re-add gradient checkpointing attribute --- examples/modular-transformers/modeling_dummy.py | 1 + .../modeling_multimodal1.py | 1 + examples/modular-transformers/modeling_super.py | 1 + src/transformers/models/gemma/modeling_gemma.py | 1 + src/transformers/models/gemma/modular_gemma.py | 17 ++--------------- .../models/granite/modeling_granite.py | 1 + src/transformers/models/llama/modeling_llama.py | 3 ++- .../models/mistral/modeling_mistral.py | 1 + .../models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + src/transformers/models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2/modular_qwen2.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 1 + 15 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ea05b9f5fa1a..7a049fa0b9e3 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -453,6 +453,7 @@ def __init__(self, config: DummyConfig): ) self.norm = DummyRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = DummyRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 4b8ca8a6689b..6bbd2f3dcad2 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -453,6 +453,7 @@ def __init__(self, config: Multimodal1TextConfig): ) self.norm = Multimodal1TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Multimodal1TextRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 80129ce40c3a..0fd187536ea4 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -450,6 +450,7 @@ def __init__(self, config: SuperConfig): ) self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = SuperRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2b92e21ace96..ca17084da1eb 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -483,6 +483,7 @@ def __init__(self, config: GemmaConfig): ) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index b2a0fcdecf6c..7295a0cbd3c2 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -401,19 +401,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -488,8 +477,6 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f7aebb50d967..acbefc13d7ce 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -500,6 +500,7 @@ def __init__(self, config: GraniteConfig): ) self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GraniteRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 35d7ce410a86..f7128ba7afc2 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -488,7 +488,8 @@ def __init__(self, config: LlamaConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) - + self.gradient_checkpointing = False + # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 99c8335d8db1..4fa684b2d37e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -471,6 +471,7 @@ def __init__(self, config: MistralConfig): ) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MistralRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 53482843b3c0..ecac87a303c9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -585,6 +585,7 @@ def __init__(self, config: MixtralConfig): ) self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MixtralRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1a45a34734f1..2e74dd39099e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -484,6 +484,7 @@ def __init__(self, config: OlmoConfig): ) self.norm = OlmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = OlmoRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 9330fb317e5a..a02afd91d752 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -463,6 +463,7 @@ def __init__(self, config: Olmo2Config): ) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Olmo2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index ab00d8c4c671..512f88d8b1dc 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -491,6 +491,7 @@ def __init__(self, config: PhiConfig): ) self.norm = PhiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = PhiRotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 4bc7a32df5f7..8681bd83dc64 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -484,6 +484,7 @@ def __init__(self, config: Qwen2Config): ) self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Qwen2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 01609b0ca3a7..7519db0877d8 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -4,10 +4,10 @@ import torch.utils.checkpoint from torch import nn -from ...processing_utils import Unpack from ...cache_utils import Cache from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import logging from ..llama.modeling_llama import ( LlamaAttention, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c5dd8cc815e2..49fd2c3ea6c2 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -473,6 +473,7 @@ def __init__(self, config: Starcoder2Config): ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.rotary_emb = Starcoder2RotaryEmbedding(config=config) + self.gradient_checkpointing = False self.embedding_dropout = config.embedding_dropout # Initialize weights and apply final processing From 607e928ed1167c226c8e00c2aa2017430a7967db Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 15:13:53 +0100 Subject: [PATCH 35/97] fix --- src/transformers/models/gemma/modeling_gemma.py | 17 ++--------------- .../models/gemma2/modeling_gemma2.py | 1 + 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ca17084da1eb..f557fbb74298 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -527,19 +527,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -614,8 +603,6 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0688b3756c46..8196a9ff68b7 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -517,6 +517,7 @@ def __init__(self, config: Gemma2Config): ) self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = Gemma2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() From 461259527752424ab6f8d2ec6084ee81a7492f99 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 15:14:19 +0100 Subject: [PATCH 36/97] style --- src/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f7128ba7afc2..cddf68c1af6b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -489,7 +489,7 @@ def __init__(self, config: LlamaConfig): self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = LlamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - + # Initialize weights and apply final processing self.post_init() From 20c376cb7e6046b885e6c5ea29f732f2afe74091 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 16:49:45 +0100 Subject: [PATCH 37/97] make config necessary --- .../modular-transformers/modeling_dummy.py | 2 +- .../modeling_multimodal1.py | 2 +- .../modeling_my_new_model2.py | 18 +++--------------- .../modular-transformers/modeling_super.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- .../models/cohere2/modeling_cohere2.py | 1 + src/transformers/models/dbrx/modeling_dbrx.py | 2 +- .../models/falcon/modeling_falcon.py | 2 +- .../models/gemma/modeling_gemma.py | 14 ++++---------- src/transformers/models/gemma/modular_gemma.py | 12 +++--------- .../models/gpt_neox/modeling_gpt_neox.py | 2 +- .../modeling_gpt_neox_japanese.py | 2 +- .../models/granite/modeling_granite.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../models/jetmoe/modeling_jetmoe.py | 2 +- .../models/llama/modeling_llama.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 8 ++------ .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- .../models/moshi/modeling_moshi.py | 8 ++------ src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- .../models/olmoe/modeling_olmoe.py | 2 +- .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- 30 files changed, 39 insertions(+), 70 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 7a049fa0b9e3..e80f6cc57d14 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -47,8 +47,8 @@ def extra_repr(self): class DummyRotaryEmbedding(nn.Module): def __init__( self, + config: DummyConfig, device=None, - config: Optional[DummyConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 6bbd2f3dcad2..f775153418f4 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -47,8 +47,8 @@ def extra_repr(self): class Multimodal1TextRotaryEmbedding(nn.Module): def __init__( self, + config: Multimodal1TextConfig, device=None, - config: Optional[Multimodal1TextConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 404b83288aa9..9753b3a22b96 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -453,6 +453,7 @@ def __init__(self, config: MyNewModel2Config): ) self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MyNewModel2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() @@ -496,19 +497,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True # noqa: F841 - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if cache_position is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 @@ -583,8 +573,6 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 0fd187536ea4..d1c1eea2b154 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -44,8 +44,8 @@ def extra_repr(self): class SuperRotaryEmbedding(nn.Module): def __init__( self, + config: SuperConfig, device=None, - config: Optional[SuperConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 31a6bcefb6f2..c74edf450fe2 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -708,8 +708,8 @@ def _init_weights(self, module): class AriaTextRotaryEmbedding(nn.Module): def __init__( self, + config: AriaTextConfig, device=None, - config: Optional[AriaTextConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index c328690efc8c..1ffa4bffddc3 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -660,6 +660,7 @@ def __init__(self, config: Cohere2Config): ) self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) self.rotary_emb = Cohere2RotaryEmbedding(config=config) + self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 9294fff11db5..9491343d45ac 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -51,8 +51,8 @@ class DbrxRotaryEmbedding(nn.Module): def __init__( self, + config: DbrxConfig, device=None, - config: Optional[DbrxConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e2f609bc348a..b31b47033c34 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -113,8 +113,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class FalconRotaryEmbedding(nn.Module): def __init__( self, + config: FalconConfig, device=None, - config: Optional[FalconConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f557fbb74298..f92e61a4782d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -323,8 +323,8 @@ def _init_weights(self, module): class GemmaRotaryEmbedding(nn.Module): def __init__( self, + config: GemmaConfig, device=None, - config: Optional[GemmaConfig] = None, ): super().__init__() self.rope_kwargs = {} @@ -590,9 +590,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -602,16 +599,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 7295a0cbd3c2..2662ebc97101 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -464,9 +464,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -476,16 +473,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() # Example where we ony modify the docstring and call super diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index c0559d0e4cbc..62f2bc8709c1 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -490,8 +490,8 @@ def __init__(self, config, layer_idx=None): class GPTNeoXRotaryEmbedding(nn.Module): def __init__( self, + config: GPTNeoXConfig, device=None, - config: Optional[GPTNeoXConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 498573f065d1..f7aad9a78542 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -227,8 +227,8 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): class GPTNeoXJapaneseRotaryEmbedding(nn.Module): def __init__( self, + config: GPTNeoXJapaneseConfig, device=None, - config: Optional[GPTNeoXJapaneseConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index acbefc13d7ce..e5f8577f144d 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -295,8 +295,8 @@ def forward( class GraniteRotaryEmbedding(nn.Module): def __init__( self, + config: GraniteConfig, device=None, - config: Optional[GraniteConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index ce36905ba23c..d64d8c852139 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -160,8 +160,8 @@ def extra_repr(self): class GraniteMoeRotaryEmbedding(nn.Module): def __init__( self, + config: GraniteMoeConfig, device=None, - config: Optional[GraniteMoeConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 499b15a382a5..94d82c0465f2 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -388,8 +388,8 @@ def extra_repr(self): class JetMoeRotaryEmbedding(nn.Module): def __init__( self, + config: JetMoeConfig, device=None, - config: Optional[JetMoeConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index cddf68c1af6b..1ef0700cc2ae 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -82,8 +82,8 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__( self, + config: LlamaConfig, device=None, - config: Optional[LlamaConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 68ee9770d5a7..913af9811738 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -367,8 +367,8 @@ def forward(self, x: torch.Tensor): class MimiRotaryEmbedding(nn.Module): def __init__( self, + config: MimiConfig, device=None, - config: Optional[MimiConfig] = None, ): super().__init__() self.rope_kwargs = {} @@ -531,11 +531,7 @@ def __init__(self, config: MimiConfig, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self.rotary_emb = MimiRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = MimiRotaryEmbedding(config) self.sliding_window = config.sliding_window # Ignore copy def forward( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4fa684b2d37e..e8fe41c9f21c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -81,8 +81,8 @@ def extra_repr(self): class MistralRotaryEmbedding(nn.Module): def __init__( self, + config: MistralConfig, device=None, - config: Optional[MistralConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ecac87a303c9..ce5b2cbf9284 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -380,8 +380,8 @@ def forward( class MixtralRotaryEmbedding(nn.Module): def __init__( self, + config: MixtralConfig, device=None, - config: Optional[MixtralConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 5fc97e3e2f73..9a059766837c 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -310,8 +310,8 @@ def forward(self, x, layer_idx=None): class MoshiRotaryEmbedding(nn.Module): def __init__( self, + config: MoshiConfig, device=None, - config: Optional[MoshiConfig] = None, ): super().__init__() self.rope_kwargs = {} @@ -493,11 +493,7 @@ def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_fle self.rotary_emb = None if use_rope: self.rope_theta = config.rope_theta - self.rotary_emb = MoshiRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = MoshiRotaryEmbedding(config) # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward # no longer copied after attention refactors diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 2e74dd39099e..053d37ab3311 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -279,8 +279,8 @@ def forward( class OlmoRotaryEmbedding(nn.Module): def __init__( self, + config: OlmoConfig, device=None, - config: Optional[OlmoConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index a02afd91d752..fbc40e75a573 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -303,8 +303,8 @@ def _init_weights(self, module): class Olmo2RotaryEmbedding(nn.Module): def __init__( self, + config: Olmo2Config, device=None, - config: Optional[Olmo2Config] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 7ceaa3622b6e..5e6cbc0ee346 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -160,8 +160,8 @@ def extra_repr(self): class OlmoeRotaryEmbedding(nn.Module): def __init__( self, + config: OlmoeConfig, device=None, - config: Optional[OlmoeConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 765711711ca8..b170f85872fe 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -59,8 +59,8 @@ class PersimmonRotaryEmbedding(nn.Module): def __init__( self, + config: PersimmonConfig, device=None, - config: Optional[PersimmonConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 512f88d8b1dc..69f8d370037d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -286,8 +286,8 @@ def extra_repr(self): class PhiRotaryEmbedding(nn.Module): def __init__( self, + config: PhiConfig, device=None, - config: Optional[PhiConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 35f26971003b..e3dac859360e 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -79,8 +79,8 @@ def extra_repr(self): class Phi3RotaryEmbedding(nn.Module): def __init__( self, + config: Phi3Config, device=None, - config: Optional[Phi3Config] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8681bd83dc64..dc117000d521 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -279,8 +279,8 @@ def forward( class Qwen2RotaryEmbedding(nn.Module): def __init__( self, + config: Qwen2Config, device=None, - config: Optional[Qwen2Config] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a3c0d36e2a5a..7b24a29d852c 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -169,8 +169,8 @@ def extra_repr(self): class Qwen2MoeRotaryEmbedding(nn.Module): def __init__( self, + config: Qwen2MoeConfig, device=None, - config: Optional[Qwen2MoeConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 2abde43770f8..7674bafbdc6f 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -65,8 +65,8 @@ class StableLmRotaryEmbedding(nn.Module): def __init__( self, + config: StableLmConfig, device=None, - config: Optional[StableLmConfig] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 49fd2c3ea6c2..5c7ed83d66ca 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -62,8 +62,8 @@ class Starcoder2RotaryEmbedding(nn.Module): def __init__( self, + config: Starcoder2Config, device=None, - config: Optional[Starcoder2Config] = None, ): super().__init__() self.rope_kwargs = {} From 0ac9db2b025931e129a0766ff36ddca617bb9129 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 16:50:01 +0100 Subject: [PATCH 38/97] bis --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8196a9ff68b7..5eed6a656e79 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -357,8 +357,8 @@ def _init_weights(self, module): class Gemma2RotaryEmbedding(nn.Module): def __init__( self, + config: Gemma2Config, device=None, - config: Optional[Gemma2Config] = None, ): super().__init__() self.rope_kwargs = {} diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 15662c360fc5..dfa43fc8ff30 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -78,8 +78,8 @@ def extra_repr(self): class GlmRotaryEmbedding(nn.Module): def __init__( self, + config: GlmConfig, device=None, - config: Optional[GlmConfig] = None, ): super().__init__() self.rope_kwargs = {} From 349b7ab8b7ea9e14e28ba16f3b6fac4da05865da Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 16:52:18 +0100 Subject: [PATCH 39/97] bis --- .../modular-transformers/modeling_my_new_model2.py | 14 ++++---------- src/transformers/models/gemma/modeling_gemma.py | 1 - src/transformers/models/gemma/modular_gemma.py | 1 - 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 9753b3a22b96..f770749d7f21 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -293,8 +293,8 @@ def _init_weights(self, module): class MyNewModel2RotaryEmbedding(nn.Module): def __init__( self, + config: MyNewModel2Config, device=None, - config: Optional[MyNewModel2Config] = None, ): super().__init__() self.rope_kwargs = {} @@ -560,9 +560,6 @@ def forward( hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -572,16 +569,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f92e61a4782d..c9fd0900a802 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -558,7 +558,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 2662ebc97101..548552980579 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -432,7 +432,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: From defa88fff48f80f333006fc476c5bdc31fc988df Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 16:55:38 +0100 Subject: [PATCH 40/97] Update modeling_my_new_model2.py --- examples/modular-transformers/modeling_my_new_model2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index f770749d7f21..c368bfe500eb 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -528,7 +528,6 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None for decoder_layer in self.layers[: self.config.num_hidden_layers]: if output_hidden_states: From fbf4b5523a03d514bcd034eeae09414d9035cc7c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 17:08:50 +0100 Subject: [PATCH 41/97] is_causal attr --- .../modular-transformers/modeling_dummy.py | 1 + .../modeling_multimodal1.py | 1 + .../modeling_my_new_model2.py | 1 + .../modular-transformers/modeling_super.py | 1 + .../integrations/flash_attention.py | 4 +++- src/transformers/models/aria/modeling_aria.py | 1 + .../models/gemma/modeling_gemma.py | 1 + .../models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 1 + .../models/granite/modeling_granite.py | 1 + .../models/llama/modeling_llama.py | 1 + .../models/mistral/modeling_mistral.py | 1 + .../models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + .../models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + .../models/qwen2/modeling_qwen2.py | 18 +++++------------- src/transformers/models/qwen2/modular_qwen2.py | 8 ++++++++ .../models/starcoder2/modeling_starcoder2.py | 1 + 19 files changed, 32 insertions(+), 15 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index e80f6cc57d14..c96529b07b16 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -198,6 +198,7 @@ def __init__(self, config: DummyConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index f775153418f4..d49471e6e92e 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -198,6 +198,7 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index c368bfe500eb..d4a07629bf80 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -133,6 +133,7 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d1c1eea2b154..06b5a52dc132 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -195,6 +195,7 @@ def __init__(self, config: SuperConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 333a22bbcf98..76bb041e0795 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -3,7 +3,9 @@ import torch from ..modeling_flash_attention_utils import _flash_attention_forward +from ..utils import is_flash_attn_greater_or_equal_2_10 +_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() def flash_attention_forward( module: torch.nn.Module, @@ -47,7 +49,7 @@ def flash_attention_forward( softmax_scale=scaling, sliding_window=sliding_window, softcap=softcap, - use_top_left_mask=module._flash_attn_uses_top_left_mask, + use_top_left_mask=_use_top_left_mask, **kwargs, ) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c74edf450fe2..4ec49f658043 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -503,6 +503,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c9fd0900a802..9f54f1d9d1cc 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -163,6 +163,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 5eed6a656e79..709e4f8cdaa9 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -174,6 +174,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias @@ -189,7 +190,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): ) self.attn_logit_softcapping = self.config.attn_logit_softcapping self.attention_dropout = self.config.attention_dropout - self.is_causal = True self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index dfa43fc8ff30..56e27db279ef 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -247,6 +247,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = 1 / math.sqrt(self.head_dim) + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index e5f8577f144d..34879634dcb1 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -156,6 +156,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.attention_multiplier + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 1ef0700cc2ae..21683893e295 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -233,6 +233,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e8fe41c9f21c..906dac7de0bc 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -216,6 +216,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ce5b2cbf9284..2e96239906b8 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -242,6 +242,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 053d37ab3311..d264fa66962d 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -156,6 +156,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index fbc40e75a573..762375e979a1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -125,6 +125,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 69f8d370037d..4c6aafbd49fd 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -115,6 +115,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index dc117000d521..6925cc87d430 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -130,19 +130,11 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.is_causal = True + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 7519db0877d8..1b55cf31af35 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -36,6 +36,14 @@ def __init__(self, config): class Qwen2Attention(LlamaAttention): + + def __init__(self, config: Qwen2Config, layer_idx: int): + super().__init__(config, layer_idx) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5c7ed83d66ca..456f60e88596 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -214,6 +214,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.is_causal = True self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias From 9104d0a465c37813934180de3b2437417976f17f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 17:33:19 +0100 Subject: [PATCH 42/97] fix --- .../modeling_my_new_model2.py | 134 ++++++------ .../integrations/flash_attention.py | 2 + .../models/gemma/modeling_gemma.py | 151 ++++++------- .../models/gemma/modular_gemma.py | 116 +--------- src/transformers/models/glm/modeling_glm.py | 206 +++++++++--------- src/transformers/models/glm/modular_glm.py | 89 ++------ .../models/llama/modeling_llama.py | 5 +- .../models/qwen2/modular_qwen2.py | 1 - 8 files changed, 270 insertions(+), 434 deletions(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index d4a07629bf80..e5449316af30 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -44,6 +44,71 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class MyNewModel2RotaryEmbedding(nn.Module): + def __init__( + self, + config: MyNewModel2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class MyNewModel2MLP(nn.Module): def __init__(self, config): super().__init__() @@ -196,8 +261,8 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int): self.self_attn = MyNewModel2Attention(config=config, layer_idx=layer_idx) self.mlp = MyNewModel2MLP(config) - self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size) - self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size) + self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -291,71 +356,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -class MyNewModel2RotaryEmbedding(nn.Module): - def __init__( - self, - config: MyNewModel2Config, - device=None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - MY_NEW_MODEL2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 76bb041e0795..ed2d06a98cf5 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -5,8 +5,10 @@ from ..modeling_flash_attention_utils import _flash_attention_forward from ..utils import is_flash_attn_greater_or_equal_2_10 + _use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + def flash_attention_forward( module: torch.nn.Module, query: torch.Tensor, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 9f54f1d9d1cc..e16b50426d66 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -39,6 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -74,6 +75,71 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class GemmaRotaryEmbedding(nn.Module): + def __init__( + self, + config: GemmaConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class GemmaMLP(nn.Module): def __init__(self, config): super().__init__() @@ -226,8 +292,8 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) self.mlp = GemmaMLP(config) - self.input_layernorm = GemmaRMSNorm(config.hidden_size) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, @@ -321,71 +387,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -class GemmaRotaryEmbedding(nn.Module): - def __init__( - self, - config: GemmaConfig, - device=None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - GEMMA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -729,6 +730,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} @@ -776,7 +780,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -797,16 +801,16 @@ def forward( ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + >>> model = GemmaForCausalLM.from_pretrained("meta-gemma/Gemma-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-gemma/Gemma-2-7b-hf") - >>> prompt = "What is your favorite condiment?" + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -826,6 +830,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -834,7 +839,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 548552980579..a1d594684614 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -22,17 +22,14 @@ from ...cache_utils import Cache, DynamicCache from ...configuration_utils import PretrainedConfig -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...modeling_outputs import BaseModelOutputWithPast from ...tokenization_utils import AddedToken, PreTrainedTokenizer from ...utils import logging from ..llama.modeling_llama import ( - LlamaDecoderLayer, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, LlamaModel, - LlamaPreTrainedModel, ) from ..llama.tokenization_llama import LlamaTokenizer @@ -346,29 +343,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" -ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) - - -class GemmaDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: GemmaConfig, layer_idx: int): - super().__init__(config, layer_idx) - self.input_layernorm = GemmaRMSNorm(config.hidden_size) - self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size) - - -class GemmaPreTrainedModel(LlamaPreTrainedModel): - pass - - class GemmaModel(LlamaModel): - def __init__(self, config: GemmaConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_init() - def forward( self, input_ids: torch.LongTensor = None, @@ -481,97 +456,16 @@ def forward( return output if return_dict else output.to_tuple() -# Example where we ony modify the docstring and call super class GemmaForCausalLM(LlamaForCausalLM): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" - ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + pass class GemmaForSequenceClassification(LlamaForSequenceClassification): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() + pass class GemmaForTokenClassification(LlamaForTokenClassification): - def __init__(self, config): - super().__init__(config) - self.model = GemmaModel(config) - self.post_init() + pass __all__ = [ @@ -581,5 +475,5 @@ def __init__(self, config): "GemmaForCausalLM", "GemmaForSequenceClassification", "GemmaForTokenClassification", - "GemmaPreTrainedModel", + "GemmaPreTrainedModel", # noqa: F822 ] diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 56e27db279ef..f2952bdd3448 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, List, Optional, Tuple, Union import torch @@ -40,6 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -55,91 +55,6 @@ _CONFIG_FOR_DOC = "GlmConfig" -class GlmRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GlmRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class GlmRotaryEmbedding(nn.Module): - def __init__( - self, - config: GlmConfig, - device=None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class GlmMLP(nn.Module): def __init__(self, config): super().__init__() @@ -246,7 +161,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = 1 / math.sqrt(self.head_dim) + self.scaling = self.head_dim**-0.5 self.is_causal = True self.q_proj = nn.Linear( @@ -258,7 +173,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=False) def forward( self, @@ -300,8 +215,93 @@ def forward( return attn_output, attn_weights +class GlmRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GlmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GlmRotaryEmbedding(nn.Module): + def __init__( + self, + config: GlmConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class GlmDecoderLayer(nn.Module): - def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): + def __init__(self, config: GlmConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size @@ -500,11 +500,7 @@ def __init__(self, config: GlmConfig): [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = GlmRotaryEmbedding( - dim=int(config.head_dim * config.partial_rotary_factor), - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rotary_emb = GlmRotaryEmbedding(config=config) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -745,11 +741,14 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.model = GlmModel(config) self.vocab_size = config.vocab_size @@ -792,7 +791,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -813,16 +812,16 @@ def forward( ```python >>> from transformers import AutoTokenizer, GlmForCausalLM - >>> model = GlmForCausalLM.from_pretrained("google/glm-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/glm-7b") + >>> model = GlmForCausalLM.from_pretrained("meta-glm/Glm-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-glm/Glm-2-7b-hf") - >>> prompt = "What is your favorite condiment?" + >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -842,6 +841,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -850,7 +850,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -881,7 +881,7 @@ def forward( GLM_START_DOCSTRING, ) class GlmForSequenceClassification(GlmPreTrainedModel): - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GlmModel(config) @@ -977,7 +977,7 @@ def forward( GLM_START_DOCSTRING, ) class GlmForTokenClassification(GlmPreTrainedModel): - def __init__(self, config: GlmConfig): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = GlmModel(config) diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 42192f1b7d6b..1f1e023c5a9d 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Optional import torch @@ -21,24 +20,13 @@ import torch.utils.checkpoint from ...utils import logging -from ..gemma.modeling_gemma import ( - GemmaForCausalLM, - GemmaForSequenceClassification, - GemmaForTokenClassification, -) -from ..granite.modeling_granite import ( - GraniteAttention, -) from ..llama.modeling_llama import ( - LlamaDecoderLayer, - LlamaModel, - LlamaPreTrainedModel, -) -from ..phi3.modeling_phi3 import ( - Phi3MLP, - Phi3RMSNorm, - Phi3RotaryEmbedding, + LlamaAttention, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlamaForTokenClassification, ) +from ..phi3.modeling_phi3 import Phi3MLP from .configuration_glm import GlmConfig @@ -47,14 +35,6 @@ _CHECKPOINT_FOR_DOC = "THUDM/glm-4-9b" -class GlmRMSNorm(Phi3RMSNorm): - pass - - -class GlmRotaryEmbedding(Phi3RotaryEmbedding): - pass - - class GlmMLP(Phi3MLP): pass @@ -108,68 +88,27 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): return q_embed, k_embed -class GlmAttention(GraniteAttention): +class GlmAttention(LlamaAttention): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.scaling = 1 / math.sqrt(self.head_dim) - - -class GlmDecoderLayer(LlamaDecoderLayer): - def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): - super().__init__() - - self.mlp = GlmMLP(config) - self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.o_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=False) -class GlmPreTrainedModel(LlamaPreTrainedModel): +class GlmForCausalLM(LlamaForCausalLM): pass -class GlmModel(GlmPreTrainedModel, LlamaModel): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.layers = nn.ModuleList( - [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = GlmRotaryEmbedding( - dim=int(config.head_dim * config.partial_rotary_factor), - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - -class GlmForCausalLM(GemmaForCausalLM): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() - - -class GlmForSequenceClassification(GemmaForSequenceClassification): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() +class GlmForSequenceClassification(LlamaForSequenceClassification): + pass -class GlmForTokenClassification(GemmaForTokenClassification): - def __init__(self, config: GlmConfig): - super().__init__(config) - self.model = GlmModel(config) - self.post_init() +class GlmForTokenClassification(LlamaForTokenClassification): + pass __all__ = [ - "GlmPreTrainedModel", - "GlmModel", + "GlmPreTrainedModel", # noqa: F822 + "GlmModel", # noqa: F822 "GlmForCausalLM", "GlmForSequenceClassification", "GlmForTokenClassification", diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 21683893e295..7ddb5805bc80 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -316,7 +316,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -340,9 +340,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 1b55cf31af35..52407f4a5bf2 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -36,7 +36,6 @@ def __init__(self, config): class Qwen2Attention(LlamaAttention): - def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__(config, layer_idx) self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) From 0b093400d1bf37da323b5075924a5fcb3a78d4be Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 17:35:57 +0100 Subject: [PATCH 43/97] remove past kv return from decoder layer --- examples/modular-transformers/modeling_dummy.py | 5 +---- examples/modular-transformers/modeling_multimodal1.py | 5 +---- examples/modular-transformers/modeling_my_new_model2.py | 5 +---- examples/modular-transformers/modeling_super.py | 5 +---- src/transformers/models/aria/modeling_aria.py | 5 +---- src/transformers/models/gemma/modeling_gemma.py | 5 +---- src/transformers/models/glm/modeling_glm.py | 5 +---- src/transformers/models/mistral/modeling_mistral.py | 5 +---- src/transformers/models/olmo/modeling_olmo.py | 5 +---- src/transformers/models/qwen2/modeling_qwen2.py | 5 +---- src/transformers/models/starcoder2/modeling_starcoder2.py | 5 +---- 11 files changed, 11 insertions(+), 44 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index c96529b07b16..ccc85ad40b02 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -281,7 +281,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -305,9 +305,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index d49471e6e92e..6895a83f1b93 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -281,7 +281,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -305,9 +305,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index e5449316af30..ef5b32cc0c39 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -281,7 +281,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -305,9 +305,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 06b5a52dc132..8cca9b0bfe23 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -278,7 +278,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -302,9 +302,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 4ec49f658043..8d9a97d7be1e 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -597,7 +597,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -621,9 +621,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e16b50426d66..f804ec711111 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -312,7 +312,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -336,9 +336,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f2952bdd3448..0b9693d9b4dd 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -328,7 +328,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -352,9 +352,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 906dac7de0bc..305d55800729 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -299,7 +299,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -323,9 +323,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index d264fa66962d..fb021ec49d25 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -247,7 +247,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -271,9 +271,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6925cc87d430..8d7db61c452d 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -238,7 +238,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -262,9 +262,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 456f60e88596..16de76af37aa 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -301,7 +301,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -325,9 +325,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs From 46a0df794aa6840df80c3fc42f350324dd15e34f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 17:52:48 +0100 Subject: [PATCH 44/97] fix --- src/transformers/models/dbrx/modeling_dbrx.py | 6 +----- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/glm/modular_glm.py | 2 +- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 9491343d45ac..137efa80f9af 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -272,11 +272,7 @@ def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None): self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_emb = DbrxRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = DbrxRotaryEmbedding(config) def forward( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 0b9693d9b4dd..16eb158d6613 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -173,7 +173,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) - self.o_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, diff --git a/src/transformers/models/glm/modular_glm.py b/src/transformers/models/glm/modular_glm.py index 1f1e023c5a9d..ec07be10fb6a 100644 --- a/src/transformers/models/glm/modular_glm.py +++ b/src/transformers/models/glm/modular_glm.py @@ -91,7 +91,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class GlmAttention(LlamaAttention): def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.o_proj = nn.Linear(config.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) class GlmForCausalLM(LlamaForCausalLM): From aedd88a75cd7bc2d86f3e46503158f717bd2b7ef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 17:59:32 +0100 Subject: [PATCH 45/97] default rope config --- examples/modular-transformers/modeling_dummy.py | 2 +- examples/modular-transformers/modeling_multimodal1.py | 2 +- examples/modular-transformers/modeling_my_new_model2.py | 2 +- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/granite/modeling_granite.py | 2 +- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/moshi/modeling_moshi.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 30 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ccc85ad40b02..b1664c93b50d 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 6895a83f1b93..61afedf402fd 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index ef5b32cc0c39..c6038c2f26ac 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 8cca9b0bfe23..3dac56c4ce3d 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -50,7 +50,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 8d9a97d7be1e..71335a69d604 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -712,7 +712,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 137efa80f9af..6b21e434e755 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -57,7 +57,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b31b47033c34..b710c7741303 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -119,7 +119,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f804ec711111..08cdd6954f16 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -84,7 +84,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 709e4f8cdaa9..9d3504429dff 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -363,7 +363,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 16eb158d6613..3066cfddeaa5 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -244,7 +244,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 62f2bc8709c1..e7e54f546dda 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -496,7 +496,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index f7aad9a78542..fb0aeacb4ad3 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -233,7 +233,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 34879634dcb1..666179df29b7 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -302,7 +302,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index d64d8c852139..669963c24f66 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -166,7 +166,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 94d82c0465f2..bd0324a6deb2 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -394,7 +394,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7ddb5805bc80..c0e6d5c4160a 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -88,7 +88,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 913af9811738..1b5baadbe4b0 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -373,7 +373,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 305d55800729..9f4ecc0edafd 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -87,7 +87,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2e96239906b8..9af11d1b3415 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -387,7 +387,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 9a059766837c..513559131f3f 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -316,7 +316,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index fb021ec49d25..f85797419d5e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -283,7 +283,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 762375e979a1..d614e408aed8 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -310,7 +310,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 5e6cbc0ee346..c667cf577f39 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -166,7 +166,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index b170f85872fe..9438a09b27a0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -65,7 +65,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4c6aafbd49fd..94d89695583f 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -293,7 +293,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e3dac859360e..a89dc5202d8b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -85,7 +85,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8d7db61c452d..0637e1a7cddd 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -274,7 +274,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 7b24a29d852c..7e7a1f106152 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -175,7 +175,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 7674bafbdc6f..be3c0a1e5bf3 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -71,7 +71,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 16de76af37aa..237d277f2dac 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -68,7 +68,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: + if config.get("rope_scaling", None) is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" From 57e9b49f04dc8576769927f0c65ca8bf3b878113 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:05:43 +0100 Subject: [PATCH 46/97] correctly fix rope config --- examples/modular-transformers/modeling_dummy.py | 2 +- examples/modular-transformers/modeling_multimodal1.py | 2 +- examples/modular-transformers/modeling_my_new_model2.py | 2 +- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- src/transformers/models/falcon/modeling_falcon.py | 2 +- src/transformers/models/gemma/modeling_gemma.py | 2 +- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/gpt_neox/modeling_gpt_neox.py | 2 +- .../models/gpt_neox_japanese/modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/granite/modeling_granite.py | 2 +- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- src/transformers/models/jetmoe/modeling_jetmoe.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mimi/modeling_mimi.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/moshi/modeling_moshi.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 30 files changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b1664c93b50d..82ca2d349e69 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 61afedf402fd..068999259f96 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index c6038c2f26ac..a2c2ef584fea 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -53,7 +53,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 3dac56c4ce3d..09dc2351f5c3 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -50,7 +50,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 71335a69d604..3c60320fabb9 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -712,7 +712,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 6b21e434e755..d1fc36873e62 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -57,7 +57,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index b710c7741303..8d5a224f4f66 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -119,7 +119,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 08cdd6954f16..789cac2937ce 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -84,7 +84,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 9d3504429dff..46d799e9e9ba 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -363,7 +363,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 3066cfddeaa5..908b27fc9311 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -244,7 +244,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e7e54f546dda..7152d72f5b7f 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -496,7 +496,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index fb0aeacb4ad3..71602f01e7d6 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -233,7 +233,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 666179df29b7..4f0b63b187a1 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -302,7 +302,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 669963c24f66..1c4c06bbc8d7 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -166,7 +166,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bd0324a6deb2..04b1cc47e6dd 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -394,7 +394,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c0e6d5c4160a..307a12cb4ae7 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -88,7 +88,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 1b5baadbe4b0..1440ce1e075c 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -373,7 +373,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 9f4ecc0edafd..2a86059ca96d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -87,7 +87,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9af11d1b3415..ef48d1c3e869 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -387,7 +387,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 513559131f3f..f0281f57cf1c 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -316,7 +316,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index f85797419d5e..a0710b6b70b9 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -283,7 +283,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index d614e408aed8..9cce883d66a4 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -310,7 +310,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index c667cf577f39..fa3c2f3cd4d1 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -166,7 +166,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 9438a09b27a0..8d3c20b9ace7 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -65,7 +65,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 94d89695583f..8904413dc94b 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -293,7 +293,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index a89dc5202d8b..de6326850a41 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -85,7 +85,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 0637e1a7cddd..26451b964875 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -274,7 +274,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 7e7a1f106152..66d6ce2cd8d6 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -175,7 +175,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index be3c0a1e5bf3..88dc437cdcb9 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -71,7 +71,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 237d277f2dac..593dde952ec9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -68,7 +68,7 @@ def __init__( super().__init__() self.rope_kwargs = {} # BC: "rope_type" was originally "type" - if config.get("rope_scaling", None) is not None: + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) else: self.rope_type = "default" From fe90ec05707f2d94b01197f34eeafee4dbdce2bc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:19:24 +0100 Subject: [PATCH 47/97] fix bias --- .../models/gemma/modeling_gemma.py | 32 +++++++++---------- .../models/gemma/modular_gemma.py | 9 ++++++ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 789cac2937ce..808e971b8c7d 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -75,6 +75,22 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class GemmaRotaryEmbedding(nn.Module): def __init__( self, @@ -140,22 +156,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class GemmaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index a1d594684614..2c4c26374204 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -29,6 +29,7 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, + LlamaMLP, LlamaModel, ) from ..llama.tokenization_llama import LlamaTokenizer @@ -343,6 +344,14 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class GemmaMLP(LlamaMLP): + def __init__(self, config): + super().__init__() + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + class GemmaModel(LlamaModel): def forward( self, From a3f50d0f7e44121bbe4022bac8a54e27e62288bd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:20:29 +0100 Subject: [PATCH 48/97] fix gpt2 attention output --- .../decision_transformer/modeling_decision_transformer.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 1bc96f19bb39..e80926ef8222 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -399,7 +399,7 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index e37a8100760e..32ff0458937e 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -409,7 +409,7 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) From 6a92c706b71396880da2c40b1319aff399ae660d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:27:26 +0100 Subject: [PATCH 49/97] fix test --- tests/test_modeling_common.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 13eacc4a5965..bc03855846fe 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -3876,20 +3876,6 @@ def test_sdpa_can_dispatch_non_composite_models(self): model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ From a28ad195f845f87d9f62662dc92d53c93fb3b777 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:40:27 +0100 Subject: [PATCH 50/97] fix inits --- .../modeling_my_new_model2.py | 32 +++++++++---------- .../models/jetmoe/modeling_jetmoe.py | 6 +--- .../models/olmo2/modeling_olmo2.py | 4 +-- .../models/olmo2/modular_olmo2.py | 4 +-- src/transformers/models/phi3/modeling_phi3.py | 6 +--- 5 files changed, 22 insertions(+), 30 deletions(-) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index a2c2ef584fea..325d2ba801a4 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -44,6 +44,22 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.eps}" +class MyNewModel2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class MyNewModel2RotaryEmbedding(nn.Module): def __init__( self, @@ -109,22 +125,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class MyNewModel2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 04b1cc47e6dd..7b7fd5a90d69 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -523,11 +523,7 @@ def __init__(self, config: JetMoeConfig, layer_idx: Optional[int] = None): self.kv_proj = torch.nn.Linear(config.hidden_size, self.kv_projection_size * 2, bias=False) - self.rotary_emb = JetMoeRotaryEmbedding( - config.kv_channels, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ) + self.rotary_emb = JetMoeRotaryEmbedding(config) def forward( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 9cce883d66a4..f9d101573dc8 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -139,8 +139,8 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) - self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) def forward( self, diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index a169f72c8411..8e6d1923766b 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -167,8 +167,8 @@ class Olmo2RMSNorm(LlamaRMSNorm): class Olmo2Attention(OlmoAttention): def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): super().__init__(config, layer_idx=layer_idx) - self.q_norm = Olmo2RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) - self.k_norm = Olmo2RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) def forward( self, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index de6326850a41..b3b8c6426bce 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -375,11 +375,7 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): def _init_rope(self): if self.rope_scaling is None: - self.rotary_emb = Phi3RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) + self.rotary_emb = Phi3RotaryEmbedding(self.config) else: scaling_type = self.config.rope_scaling["type"] if scaling_type == "longrope": From 9bd6c9482e0d7ab488693d840ad7ed0596ffd99a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 18:56:50 +0100 Subject: [PATCH 51/97] fix default sdpa --- src/transformers/modeling_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c502742429e7..ee576f69cfe4 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1531,10 +1531,10 @@ def _autoset_attn_implementation( config = cls._check_and_enable_flex_attn(config, hard_check_only=True) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. - # config = cls._check_and_enable_sdpa( - # config, - # hard_check_only=False if requested_attn_implementation is None else True, - # ) + config = cls._check_and_enable_sdpa( + config, + hard_check_only=False if requested_attn_implementation is None else True, + ) if ( torch.version.hip is not None From fae05e166cba00a9fd0158891c8e3e214d4bf969 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 19:07:07 +0100 Subject: [PATCH 52/97] fix default sdpa implementation --- examples/modular-transformers/modeling_dummy.py | 8 +++++++- .../modular-transformers/modeling_multimodal1.py | 8 +++++++- .../modular-transformers/modeling_my_new_model2.py | 8 +++++++- examples/modular-transformers/modeling_super.py | 13 +++++++++++-- src/transformers/models/aria/modeling_aria.py | 8 +++++++- src/transformers/models/gemma/modeling_gemma.py | 8 +++++++- src/transformers/models/glm/modeling_glm.py | 8 +++++++- src/transformers/models/granite/modeling_granite.py | 8 +++++++- src/transformers/models/llama/modeling_llama.py | 8 +++++++- src/transformers/models/mistral/modeling_mistral.py | 8 +++++++- src/transformers/models/mixtral/modeling_mixtral.py | 8 +++++++- 11 files changed, 81 insertions(+), 12 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 82ca2d349e69..d8954effac93 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -238,7 +238,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 068999259f96..a71e8640475f 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -238,7 +238,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 325d2ba801a4..c08dc066f583 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -238,7 +238,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 09dc2351f5c3..fb2b4d9d0350 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -17,10 +17,13 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_super import SuperConfig +logger = logging.get_logger(__name__) + + class SuperRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -235,7 +238,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 3c60320fabb9..e5f754739ea6 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -543,7 +543,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 808e971b8c7d..0da455d1a66b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -269,7 +269,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 908b27fc9311..812d8d66c759 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -200,7 +200,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 4f0b63b187a1..a2b99f6b3f9a 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -196,7 +196,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 307a12cb4ae7..88e66074160c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -273,7 +273,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2a86059ca96d..90d47bbdb33e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -256,7 +256,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index ef48d1c3e869..fc77d8ca4a4b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -282,7 +282,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, From 838d211d81faf1806e4df9786f25402ca1508bb5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 19:24:29 +0100 Subject: [PATCH 53/97] harmonize classes --- .../modeling_decision_transformer.py | 8 +- .../models/gemma2/modeling_gemma2.py | 8 +- .../models/gemma2/modular_gemma2.py | 8 +- src/transformers/models/gpt2/modeling_gpt2.py | 8 +- .../models/mistral/modeling_mistral.py | 187 +++++++++--------- .../models/mistral/modular_mistral.py | 10 + .../models/mixtral/modeling_mixtral.py | 17 +- src/transformers/models/olmo/modeling_olmo.py | 8 +- src/transformers/models/olmo/modular_olmo.py | 12 +- .../models/olmo2/modeling_olmo2.py | 8 +- .../models/olmo2/modular_olmo2.py | 8 +- src/transformers/models/phi/modeling_phi.py | 8 +- src/transformers/models/phi/modular_phi.py | 12 +- .../models/qwen2/modeling_qwen2.py | 8 +- .../models/qwen2/modular_qwen2.py | 8 +- .../models/starcoder2/modeling_starcoder2.py | 25 ++- .../models/starcoder2/modular_starcoder2.py | 12 +- 17 files changed, 217 insertions(+), 138 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index e80926ef8222..46f539430e73 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -296,7 +296,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 46d799e9e9ba..444ee33a488e 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -217,7 +217,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index b166ea72957e..c184d1713372 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -263,7 +263,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 32ff0458937e..45ca2fcbded7 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -311,7 +311,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if self.reorder_and_upcast_attn: attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 90d47bbdb33e..7bebfdd20e9b 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -58,91 +58,6 @@ def forward(self, x): return down_proj -class MistralRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - MistralRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class MistralRotaryEmbedding(nn.Module): - def __init__( - self, - config: MistralConfig, - device=None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -217,19 +132,10 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, @@ -277,6 +183,91 @@ def forward( return attn_output, attn_weights +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MistralRotaryEmbedding(nn.Module): + def __init__( + self, + config: MistralConfig, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + class MistralDecoderLayer(nn.Module): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__() diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 6e6a18f9ef65..ebec129c1ac2 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -5,6 +5,7 @@ from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ..llama.modeling_llama import ( + LlamaAttention, LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, @@ -26,6 +27,15 @@ def __init__(self, config): self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) +class MistralAttention(LlamaAttention): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + + class MistralModel(LlamaModel): def _update_causal_mask( self, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index fc77d8ca4a4b..c4caabf7ba1b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -243,19 +243,10 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) def forward( self, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index a0710b6b70b9..6a74776c29ae 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -205,7 +205,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 72fc011e44f8..21d1c496603c 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -7,6 +7,7 @@ from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import logging from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -19,6 +20,9 @@ from .configuration_olmo import OlmoConfig +logger = logging.get_logger(__name__) + + class OlmoLayerNorm(nn.Module): """LayerNorm but with no learnable weight or bias.""" @@ -80,7 +84,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index f9d101573dc8..8fd787f41565 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -174,7 +174,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 8e6d1923766b..2b086094f4fa 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -202,7 +202,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8904413dc94b..29b233be1e3e 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -183,7 +183,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 051b9af61f9e..99174412e800 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -6,6 +6,7 @@ from transformers.cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...utils import logging from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -18,6 +19,9 @@ from .configuration_phi import PhiConfig +logger = logging.get_logger(__name__) + + class PhiAttention(LlamaAttention): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__(config, layer_idx) @@ -75,7 +79,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 26451b964875..83cf959c7a3a 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -169,7 +169,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 52407f4a5bf2..b1ec41677f9c 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -76,7 +76,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 593dde952ec9..d04b8d2f3ed8 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -215,19 +215,10 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout @@ -256,7 +247,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index d4163b9f478b..0ba10dfc2862 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -104,6 +104,10 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) + self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) def forward( self, @@ -130,7 +134,13 @@ def forward( attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, From e0d10f653e56f931aab4f749735a340cb449d6c9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 19:38:52 +0100 Subject: [PATCH 54/97] fix mistral --- examples/modular-transformers/modeling_dummy.py | 4 ++-- examples/modular-transformers/modeling_multimodal1.py | 4 ++-- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/granite/modeling_granite.py | 2 +- src/transformers/models/llama/modeling_llama.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 5 ++--- src/transformers/models/mistral/modular_mistral.py | 3 +-- src/transformers/models/mixtral/modeling_mixtral.py | 3 +-- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- 13 files changed, 16 insertions(+), 19 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index d8954effac93..13d62338e1f9 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_dummy.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -474,7 +474,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index a71e8640475f..58ca91fd1ae6 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -4,7 +4,7 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_multimodal1.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -474,7 +474,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index e5f754739ea6..434dbd6f1639 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -889,7 +889,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 812d8d66c759..170d1edd7344 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -521,7 +521,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index a2b99f6b3f9a..0765f6529170 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -524,7 +524,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 88e66074160c..43465277418b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -509,7 +509,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 7bebfdd20e9b..4f8e2c655412 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -483,7 +483,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -590,11 +590,10 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and use_cache: + if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index ebec129c1ac2..f4827202fc34 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -43,11 +43,10 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and use_cache: + if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c4caabf7ba1b..f68d3675a242 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -736,11 +736,10 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - use_cache: bool, output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and use_cache: + if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: raise ValueError( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6a74776c29ae..fc5dae5f0195 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -505,7 +505,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 8fd787f41565..fccd695b06b1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -487,7 +487,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 29b233be1e3e..bd42588b1ddf 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -515,7 +515,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 83cf959c7a3a..346d00a96ec1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -496,7 +496,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, From b275fdc8d7f185a92679d690d57520845fdc3625 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 20:09:07 +0100 Subject: [PATCH 55/97] fix sliding window models --- .../modular-transformers/modeling_dummy.py | 29 ++- .../modeling_multimodal1.py | 29 ++- .../modeling_my_new_model2.py | 29 ++- .../modular-transformers/modeling_super.py | 29 ++- src/transformers/models/aria/modeling_aria.py | 29 ++- .../models/gemma/modeling_gemma.py | 29 ++- .../models/gemma2/modeling_gemma2.py | 1 + src/transformers/models/glm/modeling_glm.py | 29 ++- .../models/granite/modeling_granite.py | 29 ++- .../models/llama/modeling_llama.py | 29 ++- .../models/mistral/modeling_mistral.py | 29 ++- .../models/mistral/modular_mistral.py | 58 +++++ .../models/mixtral/modeling_mixtral.py | 28 ++- src/transformers/models/olmo/modeling_olmo.py | 29 ++- src/transformers/models/olmo/modular_olmo.py | 2 + .../models/olmo2/modeling_olmo2.py | 29 ++- .../models/olmo2/modular_olmo2.py | 2 + src/transformers/models/phi/modeling_phi.py | 29 ++- src/transformers/models/phi/modular_phi.py | 2 + .../models/qwen2/modeling_qwen2.py | 29 ++- .../models/qwen2/modular_qwen2.py | 2 + .../models/starcoder2/modeling_starcoder2.py | 212 +++++++++++------- .../models/starcoder2/modular_starcoder2.py | 58 ++--- 23 files changed, 539 insertions(+), 232 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 13d62338e1f9..0e73f298c486 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -198,6 +210,7 @@ def __init__(self, config: DummyConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -251,6 +264,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 58ca91fd1ae6..9ee097e950f0 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -198,6 +210,7 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -251,6 +264,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index c08dc066f583..4e88c5aa1d13 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -198,6 +210,7 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -251,6 +264,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index fb2b4d9d0350..530c25721d2a 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -171,20 +171,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -198,6 +210,7 @@ def __init__(self, config: SuperConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -251,6 +264,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 434dbd6f1639..232d169595fe 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -476,20 +476,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -503,6 +515,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -556,6 +569,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 0da455d1a66b..5e1bc5bcfac1 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -202,20 +202,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -229,6 +241,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -282,6 +295,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 444ee33a488e..ff02bc8c06b4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -174,6 +174,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 170d1edd7344..b1c85f5177ad 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -86,20 +86,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -162,6 +174,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -213,6 +226,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 0765f6529170..5bcde06c46f2 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -129,20 +129,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -156,6 +168,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.attention_multiplier + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -209,6 +222,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 43465277418b..17691d72452c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -206,20 +206,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -233,6 +245,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -286,6 +299,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4f8e2c655412..05c7bac26186 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -37,7 +37,6 @@ logger = logging.get_logger(__name__) - _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" _CONFIG_FOR_DOC = "MistralConfig" @@ -104,20 +103,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -131,6 +142,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) @@ -175,6 +187,9 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), **kwargs, ) diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index f4827202fc34..a200d4808505 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -1,9 +1,15 @@ +from typing import Callable, Optional, Tuple + import torch import torch.utils.checkpoint from torch import nn from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack +from ...utils import logging from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -12,10 +18,14 @@ LlamaForTokenClassification, LlamaMLP, LlamaModel, + apply_rotary_pos_emb, + eager_attention_forward, ) from .configuration_mistral import MistralConfig +logger = logging.get_logger(__name__) + _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1" @@ -35,6 +45,54 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + class MistralModel(LlamaModel): def _update_causal_mask( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index f68d3675a242..43c1085b5108 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -215,20 +215,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -242,6 +254,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) @@ -286,6 +299,9 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), **kwargs, ) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index fc5dae5f0195..c518bc3a0fa6 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -129,20 +129,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -156,6 +168,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -218,6 +231,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 21d1c496603c..6cd0f0d21ee8 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -97,6 +97,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index fccd695b06b1..d4a38f20ad1c 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -98,20 +98,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -125,6 +137,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -187,6 +200,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 2b086094f4fa..74873d57b409 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -215,6 +215,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index bd42588b1ddf..0e409c0c953f 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -88,20 +88,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -115,6 +127,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -196,6 +209,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 99174412e800..b3f273298643 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -92,6 +92,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 346d00a96ec1..8dd0659942b5 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -103,20 +103,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -130,6 +142,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) @@ -182,6 +195,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, sliding_window=sliding_window, **kwargs, ) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index b1ec41677f9c..47a785a27b76 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -89,6 +89,8 @@ def forward( query_states, key_states, value_states, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, sliding_window=sliding_window, **kwargs, ) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index d04b8d2f3ed8..3c76732e0ee6 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -30,7 +30,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -59,71 +59,6 @@ _CONFIG_FOR_DOC = "Starcoder2Config" -class Starcoder2RotaryEmbedding(nn.Module): - def __init__( - self, - config: Starcoder2Config, - device=None, - ): - super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config): super().__init__() @@ -187,20 +122,32 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + scaling: Optional[float] = None, + **kwargs, +): + if scaling is None: + scaling = module.head_dim**-0.5 - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights @@ -214,12 +161,12 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 + self.attention_droupout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) - self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout def forward( @@ -261,6 +208,7 @@ def forward( key_states, value_states, dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) @@ -325,6 +273,71 @@ def forward( return outputs +class Starcoder2RotaryEmbedding(nn.Module): + def __init__( + self, + config: Starcoder2Config, + device=None, + ): + super().__init__() + self.rope_kwargs = {} + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + STARCODER2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads @@ -584,6 +597,14 @@ def _update_causal_mask( output_attentions: bool, ): if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and past_key_values is not None: + is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -593,21 +614,30 @@ def _update_causal_mask( # to infer the attention mask. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, is_training=self.training, ): return None dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if using_static_cache: + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache else: target_length = ( attention_mask.shape[-1] @@ -624,6 +654,8 @@ def _update_causal_mask( device=device, cache_position=cache_position, batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, ) if ( @@ -635,7 +667,6 @@ def _update_causal_mask( # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -649,7 +680,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position( device: torch.device, cache_position: torch.Tensor, batch_size: int, - **kwargs, + config: Starcoder2Config, + past_key_values: Cache, ): """ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape @@ -657,13 +689,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Args: attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. sequence_length (`int`): The sequence length being processed. target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -672,6 +702,10 @@ def _prepare_4d_causal_attention_mask_with_cache_position( Indices depicting the position of the input sequence tokens in the sequence. batch_size (`torch.Tensor`): Batch size. + config (`Starcoder2Config`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate """ if attention_mask is not None and attention_mask.dim() == 4: # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. @@ -681,19 +715,27 @@ def _prepare_4d_causal_attention_mask_with_cache_position( causal_mask = torch.full( (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype ) - return causal_mask diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 0ba10dfc2862..cb1122f4dbc4 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -38,16 +38,16 @@ is_flash_attn_2_available, logging, ) -from ..llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForSequenceClassification, - LlamaForTokenClassification, - LlamaRotaryEmbedding, +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralForCausalLM, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, apply_rotary_pos_emb, - repeat_kv, + eager_attention_forward, ) -from ..qwen2.modeling_qwen2 import Qwen2ForCausalLM, Qwen2Model from .configuration_starcoder2 import Starcoder2Config @@ -61,10 +61,6 @@ _CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b" -class Starcoder2RotaryEmbedding(LlamaRotaryEmbedding): - pass - - class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config): super().__init__() @@ -82,27 +78,9 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -def eager_attention_forward(attention_class: nn.Module, query, key, value, attention_mask=None, **_kwargs): - config = attention_class.config - key_states = repeat_kv(key, attention_class.num_key_value_groups) - value_states = repeat_kv(value, attention_class.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * attention_class.scaling - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=config.attention_dropout, training=attention_class.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, attn_weights - - -class Starcoder2Attention(LlamaAttention): +class Starcoder2Attention(MistralAttention): def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): super().__init__() - self.attention_dropout = config.attention_dropout self.residual_dropout = config.residual_dropout self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) @@ -148,6 +126,7 @@ def forward( key_states, value_states, dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, **kwargs, ) @@ -157,7 +136,7 @@ def forward( return attn_output, attn_weights -class Starcoder2DecoderLayer(LlamaDecoderLayer): +class Starcoder2DecoderLayer(MistralDecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__(self) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -167,14 +146,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): STARCODER2_INPUTS_DOCSTRING = None # will be automatically redefined -class Starcoder2Model(Qwen2Model): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Starcoder2DecoderLayer`] - - Args: - config: Starcoder2Config - """ - +class Starcoder2Model(MistralModel): def __init__(self, config: Starcoder2Config): super().__init__(config) self.embedding_dropout = config.embedding_dropout @@ -276,15 +248,15 @@ def forward( return output if return_dict else output.to_tuple() -class Starcoder2ForCausalLM(Qwen2ForCausalLM): +class Starcoder2ForCausalLM(MistralForCausalLM): pass -class Starcoder2ForSequenceClassification(LlamaForSequenceClassification): +class Starcoder2ForSequenceClassification(MistralForSequenceClassification): pass -class Starcoder2ForTokenClassification(LlamaForTokenClassification): +class Starcoder2ForTokenClassification(MistralForTokenClassification): pass From 71eb6a2d017d859e81b71af4898ccdd90f252620 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 20:36:13 +0100 Subject: [PATCH 56/97] mixtral --- .../models/mixtral/modeling_mixtral.py | 60 +++++------- .../models/mixtral/modular_mixtral.py | 94 ++++++++----------- 2 files changed, 64 insertions(+), 90 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 43c1085b5108..27226dc7417a 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -48,6 +48,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( + LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, @@ -331,7 +332,8 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -360,14 +362,16 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -382,9 +386,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -624,6 +625,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -646,19 +648,8 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -677,11 +668,13 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -698,6 +691,7 @@ def forward( output_router_logits, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -709,13 +703,12 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -728,23 +721,14 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( + output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) + return output if return_dict else output.to_tuple() def _update_causal_mask( self, @@ -897,6 +881,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], num_experts: Optional[int] = None, @@ -1030,7 +1017,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1086,6 +1073,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -1094,7 +1082,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index a925998a5228..2c209bcd5455 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -27,22 +27,26 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, ) +from ...processing_utils import Unpack from ...utils import ( + LossKwargs, logging, ) -from ..llama.modeling_llama import ( - LlamaForCausalLM, - LlamaForQuestionAnswering, - LlamaForSequenceClassification, - LlamaForTokenClassification, - LlamaRMSNorm, +from ..mistral.modeling_mistral import ( + MistralAttention, + MistralForCausalLM, + MistralForQuestionAnswering, + MistralForSequenceClassification, + MistralForTokenClassification, + MistralModel, + MistralRMSNorm, ) -from ..mistral.modeling_mistral import MistralAttention, MistralModel from .configuration_mixtral import MixtralConfig @@ -220,7 +224,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -class MixtralRMSNorm(LlamaRMSNorm): +class MixtralRMSNorm(MistralRMSNorm): pass @@ -249,7 +253,8 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -278,14 +283,16 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -300,9 +307,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - if output_router_logits: outputs += (router_logits,) @@ -323,6 +327,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -345,19 +350,8 @@ def forward( ) use_cache = False - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) + if use_cache and past_key_values is None: + past_key_values = DynamicCache() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -376,11 +370,13 @@ def forward( hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None - next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: @@ -397,6 +393,7 @@ def forward( output_router_logits, use_cache, cache_position, + position_embeddings, ) else: layer_outputs = decoder_layer( @@ -408,13 +405,12 @@ def forward( output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -427,38 +423,27 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None - ) - return MoeModelOutputWithPast( + output = MoeModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) + return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... -class MixtralForCausalLM(LlamaForCausalLM): +class MixtralForCausalLM(MistralForCausalLM): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) - self.model = MixtralModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok - # Initialize weights and apply final processing - self.post_init() def forward( self, @@ -475,7 +460,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - **loss_kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -531,6 +516,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, + **kwargs, ) hidden_states = outputs[0] @@ -539,7 +525,7 @@ def forward( loss = None if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) aux_loss = None if output_router_logits: @@ -569,13 +555,13 @@ def forward( ) -class MixtralForSequenceClassification(LlamaForSequenceClassification): +class MixtralForSequenceClassification(MistralForSequenceClassification): pass -class MixtralForTokenClassification(LlamaForTokenClassification): +class MixtralForTokenClassification(MistralForTokenClassification): pass -class MixtralForQuestionAnswering(LlamaForQuestionAnswering): +class MixtralForQuestionAnswering(MistralForQuestionAnswering): pass From 4e25753420decbd67f7a18330a05bdb3b610b97b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 20:59:48 +0100 Subject: [PATCH 57/97] be more explicit --- .../modular-transformers/modeling_dummy.py | 4 +- .../modeling_multimodal1.py | 4 +- .../modeling_my_new_model2.py | 4 +- .../modular-transformers/modeling_super.py | 4 +- .../integrations/flash_attention.py | 6 +- .../integrations/flex_attention.py | 6 +- .../integrations/sdpa_attention.py | 6 +- src/transformers/models/aria/modeling_aria.py | 4 +- .../models/gemma/modeling_gemma.py | 4 +- .../models/gemma2/modeling_gemma2.py | 24 ++-- .../models/gemma2/modular_gemma2.py | 24 ++-- src/transformers/models/glm/modeling_glm.py | 4 +- .../models/granite/modeling_granite.py | 4 +- .../models/llama/modeling_llama.py | 4 +- .../models/mistral/modeling_mistral.py | 4 +- .../models/mistral/modular_mistral.py | 2 + .../models/mixtral/modeling_mixtral.py | 4 +- src/transformers/models/olmo/modeling_olmo.py | 9 +- src/transformers/models/olmo/modular_olmo.py | 10 +- .../models/olmo2/modeling_olmo2.py | 111 +++++++++--------- .../models/olmo2/modular_olmo2.py | 27 ++--- src/transformers/models/phi/modeling_phi.py | 4 +- src/transformers/models/phi/modular_phi.py | 2 + .../models/qwen2/modeling_qwen2.py | 4 +- .../models/qwen2/modular_qwen2.py | 2 + .../models/starcoder2/modeling_starcoder2.py | 4 +- .../models/starcoder2/modular_starcoder2.py | 2 + 27 files changed, 166 insertions(+), 121 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 0e73f298c486..c672788aee54 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 9ee097e950f0..39b221b0a603 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 4e88c5aa1d13..0324e192d97f 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 530c25721d2a..a43dd0d4948e 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -176,7 +176,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -230,6 +230,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -264,6 +265,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index ed2d06a98cf5..b33286627da9 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -14,14 +14,14 @@ def flash_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, target_dtype: torch.dtype = torch.float16, **kwargs, -): +) -> Tuple[torch.Tensor, None]: if attention_mask is not None: seq_len = attention_mask.shape[1] query = query[:, :, :seq_len] diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index dd4287921d2c..eacfb2b568b5 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -14,11 +14,11 @@ def flex_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], scaling: Optional[float] = None, softcap: Optional[float] = None, **kwargs, -): +) -> Tuple[torch.Tensor, torch.Tensor]: causal_mask = attention_mask if causal_mask is not None: causal_mask = causal_mask[:, :, :, : key.shape[-2]] diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 3a90ef9af282..805b379d2e18 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple import torch @@ -20,11 +20,11 @@ def sdpa_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, -): +) -> Tuple[torch.Tensor, None]: key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 232d169595fe..743367d7e520 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -481,7 +481,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -535,6 +535,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -569,6 +570,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 5e1bc5bcfac1..4d6c5698a9c3 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -207,7 +207,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -261,6 +261,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -295,6 +296,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index ff02bc8c06b4..62d82ee92cc8 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -140,25 +140,31 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if module.attn_logit_softcapping is not None: - attn_weights = attn_weights / module.attn_logit_softcapping + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * module.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -197,6 +203,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -231,6 +238,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c184d1713372..0cb9811e0e57 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -205,25 +205,31 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - mask: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if scaling is None: + scaling = module.head_dim**-0.5 + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * module.scaling + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if module.attn_logit_softcapping is not None: - attn_weights = attn_weights / module.attn_logit_softcapping + if softcap is not None: + attn_weights = attn_weights / softcap attn_weights = torch.tanh(attn_weights) - attn_weights = attn_weights * module.attn_logit_softcapping - if mask is not None: # no matter the length, we just slice it - causal_mask = mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights * softcap + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights @@ -242,6 +248,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -276,6 +283,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=self.attention_dropout if self.training else 0.0, scaling=self.scaling, sliding_window=self.sliding_window, diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index b1c85f5177ad..16d9cd2e464c 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -91,7 +91,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -192,6 +192,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -226,6 +227,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 5bcde06c46f2..b9a57e2f8160 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -134,7 +134,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -188,6 +188,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -222,6 +223,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 17691d72452c..24d224813966 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -211,7 +211,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -265,6 +265,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -299,6 +300,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 05c7bac26186..6fb7ee21572f 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -108,7 +108,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -153,6 +153,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -187,6 +188,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index a200d4808505..855c87c363ee 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -49,6 +49,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -83,6 +84,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 27226dc7417a..e05f9177124d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -221,7 +221,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -266,6 +266,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -300,6 +301,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=getattr(self.config, "sliding_window", None), diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index c518bc3a0fa6..02da8de50c1e 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -134,7 +134,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -187,7 +187,8 @@ def __init__(self, config: OlmoConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -208,7 +209,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -231,6 +232,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -247,6 +249,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.hidden_size = config.hidden_size self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) + self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 6cd0f0d21ee8..0ca1135ec31c 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -53,7 +53,8 @@ class OlmoAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -74,7 +75,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -97,6 +98,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -110,10 +112,6 @@ def forward( class OlmoDecoderLayer(LlamaDecoderLayer): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__(config, layer_idx) - self.hidden_size = config.hidden_size - - self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) - self.mlp = OlmoMLP(config) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index d4a38f20ad1c..ccca0499c4c1 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -103,7 +103,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -158,11 +158,9 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -177,7 +175,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -200,6 +198,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -232,6 +231,7 @@ def __init__(self, config: Olmo2Config, layer_idx: int): self.hidden_size = config.hidden_size self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) + self.mlp = Olmo2MLP(config) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -245,12 +245,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -258,6 +259,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -272,54 +274,8 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs - -OLMO2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Olmo2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", - OLMO2_START_DOCSTRING, -) -class Olmo2PreTrainedModel(PreTrainedModel): - config_class = Olmo2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Olmo2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() + return outputs class Olmo2RotaryEmbedding(nn.Module): @@ -387,6 +343,51 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +OLMO2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Olmo2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Olmo2 Model outputting raw hidden-states without any specific head on top.", + OLMO2_START_DOCSTRING, +) +class Olmo2PreTrainedModel(PreTrainedModel): + config_class = Olmo2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Olmo2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + OLMO2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -732,7 +733,7 @@ class Olmo2ForCausalLM(Olmo2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] _tp_plan = {"lm_head": "colwise_rep"} - def __init__(self, config: Olmo2Config): + def __init__(self, config): super().__init__(config) self.model = Olmo2Model(config) self.vocab_size = config.vocab_size diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 74873d57b409..fab2db95f10f 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -173,11 +173,9 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -192,7 +190,7 @@ def forward( key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: @@ -215,6 +213,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, @@ -244,12 +243,13 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -257,6 +257,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = self.post_attention_layernorm(hidden_states) @@ -271,13 +272,8 @@ def forward( outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs - -class Olmo2PreTrainedModel(OlmoPreTrainedModel): - pass + return outputs # The OLMo2 model is identical to the OLMo model, except RMSNorm is used instead of @@ -285,17 +281,12 @@ class Olmo2PreTrainedModel(OlmoPreTrainedModel): class Olmo2Model(OlmoModel): def __init__(self, config: Olmo2Config): super().__init__(config) - self.layers = nn.ModuleList( - [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # The heads now only need to redefine the model inside to the correct `RobertaModel` class Olmo2ForCausalLM(OlmoForCausalLM): - def __init__(self, config: Olmo2Config): - super().__init__(config) - self.model = Olmo2Model(config) + pass __all__ = [ diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 0e409c0c953f..e9b946f24780 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -93,7 +93,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -157,6 +157,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -209,6 +210,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index b3f273298643..46fd51e9afcb 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -40,6 +40,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -92,6 +93,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8dd0659942b5..269075e1cfb5 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -108,7 +108,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -153,6 +153,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -195,6 +196,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 47a785a27b76..46c8b0f8c658 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -47,6 +47,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -89,6 +90,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, sliding_window=sliding_window, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 3c76732e0ee6..fcfbfd26bb26 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -127,7 +127,7 @@ def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, **kwargs, @@ -173,6 +173,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -207,6 +208,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index cb1122f4dbc4..f5124e07e5db 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -91,6 +91,7 @@ def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -125,6 +126,7 @@ def forward( query_states, key_states, value_states, + attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, From 1e8712b2adf5504902c6ff15d1dbe6e5c7dcd32a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 21:02:52 +0100 Subject: [PATCH 58/97] style --- src/transformers/models/gemma2/modular_gemma2.py | 2 +- src/transformers/models/olmo2/modular_olmo2.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 0cb9811e0e57..5dca7209f12e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -213,7 +213,7 @@ def eager_attention_forward( ) -> Tuple[torch.Tensor, torch.Tensor]: if scaling is None: scaling = module.head_dim**-0.5 - + key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index fab2db95f10f..6ff6e52912b3 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -1,7 +1,6 @@ from typing import Callable, Optional, Tuple import torch -from torch import nn from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -14,7 +13,6 @@ OlmoDecoderLayer, OlmoForCausalLM, OlmoModel, - OlmoPreTrainedModel, apply_rotary_pos_emb, ) @@ -293,5 +291,5 @@ class Olmo2ForCausalLM(OlmoForCausalLM): "Olmo2Config", "Olmo2ForCausalLM", "Olmo2Model", - "Olmo2PreTrainedModel", + "Olmo2PreTrainedModel", # noqa: F822 ] From 854537b2b8349b3bd95c9919392eab9ffb3bb2d7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 16 Dec 2024 23:35:55 +0100 Subject: [PATCH 59/97] fix --- .../integrations/flash_attention.py | 8 +- .../integrations/sdpa_attention.py | 5 +- .../modeling_decision_transformer.py | 74 ++++++++++++------- src/transformers/models/gpt2/modeling_gpt2.py | 74 ++++++++++++------- 4 files changed, 97 insertions(+), 64 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index b33286627da9..2e48902a4ba6 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -22,12 +22,8 @@ def flash_attention_forward( target_dtype: torch.dtype = torch.float16, **kwargs, ) -> Tuple[torch.Tensor, None]: - if attention_mask is not None: - seq_len = attention_mask.shape[1] - query = query[:, :, :seq_len] - value = value[:, :, :seq_len] - else: - seq_len = query.shape[1] + # This is before the transpose + seq_len = query.shape[2] # FA2 uses non-transposed inputs query = query.transpose(1, 2) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 805b379d2e18..9ae94ccf0f44 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -25,8 +25,9 @@ def sdpa_attention_forward( scaling: Optional[float] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) causal_mask = attention_mask if attention_mask is not None: diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 46f539430e73..42f5cd332adc 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -101,22 +101,22 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): # Copied from transformers.models.gpt2.modeling_gpt2.eager_attention_forward -def eager_attention_forward(self, query, key, value, attention_mask=None, head_mask=None): +def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): attn_weights = torch.matmul(query, key.transpose(-1, -2)) - if self.scale_attn_weights: + if module.scale_attn_weights: attn_weights = attn_weights / torch.full( [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device ) # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) + if module.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(module.layer_idx + 1) - if not self.is_cross_attention: + if not module.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` @@ -131,7 +131,7 @@ def eager_attention_forward(self, query, key, value, attention_mask=None, head_m # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) + attn_weights = module.attn_dropout(attn_weights) # Mask heads if we want to if head_mask is not None: @@ -258,13 +258,13 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - past_key_value: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: @@ -274,29 +274,34 @@ def forward( "Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`." ) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_states = self.q_attn(hidden_states) + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query.view(hidden_shape).transpose(1, 2) - key_states = key.view(hidden_shape).transpose(1, 2) - value_states = value.view(hidden_shape).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = position_embeddings + if layer_past is not None: + past_key, past_value = layer_past + key_states = torch.cat((past_key, key_states), dim=-2) + value_states = torch.cat((past_value, value_states), dim=-2) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if use_cache is True: + present = (key_states, value_states) + else: + present = None + using_eager = self.config._attn_implementation == "eager" attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): + using_eager = True logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' @@ -304,21 +309,34 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + if using_eager and self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query_states, key_states, value_states, attention_mask, head_mask + ) else: attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, + attention_mask, + head_mask=head_mask, + dropout=self.attn_dropout.p if self.training else 0.0, **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.c_proj(attn_output) + attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.c_proj(attn_output_reshaped) attn_output = self.resid_dropout(attn_output) - return attn_output, attn_weights + + outputs = (attn_output, present) + if output_attentions: + # weird but needed to satisfy BC and tests (would normally be None) + if self.config._attn_implementation == "flash_attention_2": + attn_weights = attn_output_reshaped + outputs += (attn_weights,) + + return outputs # a, present, (attentions) # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2 @@ -405,7 +423,7 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 45ca2fcbded7..b824fef9a42a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -117,22 +117,22 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model -def eager_attention_forward(self, query, key, value, attention_mask=None, head_mask=None): +def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs): attn_weights = torch.matmul(query, key.transpose(-1, -2)) - if self.scale_attn_weights: + if module.scale_attn_weights: attn_weights = attn_weights / torch.full( [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device ) # Layer-wise attention scaling - if self.scale_attn_by_inverse_layer_idx: - attn_weights = attn_weights / float(self.layer_idx + 1) + if module.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(module.layer_idx + 1) - if not self.is_cross_attention: + if not module.is_cross_attention: # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length] mask_value = torch.finfo(attn_weights.dtype).min # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` @@ -147,7 +147,7 @@ def eager_attention_forward(self, query, key, value, attention_mask=None, head_m # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) + attn_weights = module.attn_dropout(attn_weights) # Mask heads if we want to if head_mask is not None: @@ -273,13 +273,13 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - past_key_value: Optional[Tuple[torch.Tensor]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: if encoder_hidden_states is not None: @@ -289,29 +289,34 @@ def forward( "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." ) - query = self.q_attn(hidden_states) - key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + query_states = self.q_attn(hidden_states) + key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) attention_mask = encoder_attention_mask else: - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query.view(hidden_shape).transpose(1, 2) - key_states = key.view(hidden_shape).transpose(1, 2) - value_states = value.view(hidden_shape).transpose(1, 2) + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) - cos, sin = position_embeddings + if layer_past is not None: + past_key, past_value = layer_past + key_states = torch.cat((past_key, key_states), dim=-2) + value_states = torch.cat((past_value, value_states), dim=-2) - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + if use_cache is True: + present = (key_states, value_states) + else: + present = None + using_eager = self.config._attn_implementation == "eager" attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): + using_eager = True logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' @@ -319,21 +324,34 @@ def forward( else: attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - if self.reorder_and_upcast_attn: - attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + if using_eager and self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn( + query_states, key_states, value_states, attention_mask, head_mask + ) else: attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, + attention_mask, + head_mask=head_mask, + dropout=self.attn_dropout.p if self.training else 0.0, **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.c_proj(attn_output) + attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.c_proj(attn_output_reshaped) attn_output = self.resid_dropout(attn_output) - return attn_output, attn_weights + + outputs = (attn_output, present) + if output_attentions: + # weird but needed to satisfy BC and tests (would normally be None) + if self.config._attn_implementation == "flash_attention_2": + attn_weights = attn_output_reshaped + outputs += (attn_weights,) + + return outputs # a, present, (attentions) class GPT2MLP(nn.Module): @@ -415,7 +433,7 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights residual = hidden_states hidden_states = self.ln_2(hidden_states) From 99bddf016a2f86dc55968baa1d44bf7eaf0a0557 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 01:27:47 +0100 Subject: [PATCH 60/97] several fixes --- .../modular-transformers/modeling_dummy.py | 2 +- .../modeling_multimodal1.py | 2 +- .../modeling_my_new_model2.py | 2 +- .../modular-transformers/modeling_super.py | 2 +- src/transformers/models/aria/modeling_aria.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 65 +++++-------------- .../models/gemma/modeling_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 3 +- src/transformers/models/glm/modeling_glm.py | 2 +- .../models/granite/modeling_granite.py | 2 +- .../models/idefics/modeling_idefics.py | 12 ++-- .../models/llama/modeling_llama.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- .../models/phimoe/modeling_phimoe.py | 12 ++-- .../models/qwen2/modeling_qwen2.py | 2 +- .../modeling_recurrent_gemma.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- 21 files changed, 46 insertions(+), 80 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index c672788aee54..3dedf14f46a6 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -210,7 +210,7 @@ def __init__(self, config: DummyConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 39b221b0a603..c3844b402838 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -210,7 +210,7 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 0324e192d97f..e07b31ea89cd 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -210,7 +210,7 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index a43dd0d4948e..afc45e92f490 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -210,7 +210,7 @@ def __init__(self, config: SuperConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 743367d7e520..c800f0627092 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -515,7 +515,7 @@ def __init__(self, config: AriaTextConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index d1fc36873e62..a342499698da 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -47,57 +47,25 @@ _CONFIG_FOR_DOC = "DbrxConfig" -# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx class DbrxRotaryEmbedding(nn.Module): - def __init__( - self, - config: DbrxConfig, - device=None, - ): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -105,11 +73,6 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -272,7 +235,11 @@ def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None): self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False ) self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_emb = DbrxRotaryEmbedding(config) + self.rotary_emb = DbrxRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) def forward( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 4d6c5698a9c3..043af2d85ba6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -241,7 +241,7 @@ def __init__(self, config: GemmaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 62d82ee92cc8..44cc4c31bb07 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -180,7 +180,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.query_pre_attn_scalar**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = self.config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( @@ -196,7 +196,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) self.attn_logit_softcapping = self.config.attn_logit_softcapping - self.attention_dropout = self.config.attention_dropout self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None def forward( diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 16d9cd2e464c..343ad4a4b035 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -174,7 +174,7 @@ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index b9a57e2f8160..8216caa8185b 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -168,7 +168,7 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = config.attention_multiplier - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index ea3818990c7c..b2ffbcbc6956 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -444,8 +444,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -453,8 +452,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -465,8 +465,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 24d224813966..68ecadc08ba0 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -245,7 +245,7 @@ def __init__(self, config: LlamaConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 6fb7ee21572f..2f467dc665b3 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -142,7 +142,7 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e05f9177124d..5b35cbaf61ce 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -255,7 +255,7 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 02da8de50c1e..203893ea9a03 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -168,7 +168,7 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index ccca0499c4c1..4f33223f025f 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -137,7 +137,7 @@ def __init__(self, config: Olmo2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index e9b946f24780..8d5e25281829 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -127,7 +127,7 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear( diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 496322096444..cd54b226e1d8 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -186,8 +186,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: @@ -195,8 +194,9 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -207,8 +207,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 269075e1cfb5..6f43e57d119f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -142,7 +142,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index b9d720eb39dc..5ed4049e09ae 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -191,7 +191,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + cos, sin = self.rotary_emb(value_states, position_ids) # Partial rotary embedding query_rot, query_pass = torch.chunk(query_states, int(1 / self.partial_rotary_factor), dim=-1) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index fcfbfd26bb26..c85a41a9c68b 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -161,7 +161,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None): self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.scaling = self.head_dim**-0.5 - self.attention_droupout = config.attention_dropout + self.attention_dropout = config.attention_dropout self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.use_bias) self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) From 2f666b35f143f463605f28d70bb4638580b2d0a5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 01:28:51 +0100 Subject: [PATCH 61/97] Update modeling_dbrx.py --- src/transformers/models/dbrx/modeling_dbrx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a342499698da..0d2c4297e0d4 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -26,7 +26,6 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, From bafa020d4ddb2e6b641c7482216fa620561ed4d5 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 02:04:50 +0100 Subject: [PATCH 62/97] fix test --- src/transformers/models/phi/modeling_phi.py | 20 +++++--------------- src/transformers/models/phi/modular_phi.py | 8 ++++++-- tests/test_modeling_common.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 8d5e25281829..5881411f7c17 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -129,21 +129,11 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) - self.dense = self.o_proj self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( @@ -217,7 +207,7 @@ def forward( ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) + attn_output = self.dense(attn_output) return attn_output, attn_weights diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 46fd51e9afcb..ee15fed8dfb8 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -25,8 +25,12 @@ class PhiAttention(LlamaAttention): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__(config, layer_idx) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + del self.o_proj self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) - self.dense = self.o_proj self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( @@ -100,7 +104,7 @@ def forward( ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) + attn_output = self.dense(attn_output) return attn_output, attn_weights diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index bc03855846fe..23c5f33b0802 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -43,6 +43,7 @@ logging, set_seed, ) +from transformers.cache_utils import DynamicCache from transformers.integrations import HfDeepSpeedConfig from transformers.integrations.deepspeed import ( is_deepspeed_available, @@ -1285,6 +1286,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ) for i in range(model.config.num_hidden_layers) ) + empty_pkv = ( + DynamicCache.from_legacy_cache(empty_pkv) + if model_class._supports_cache_class + else empty_pkv + ) cache_length = 9 cache_shape = (batch_size, num_heads, cache_length, head_dim) @@ -1295,6 +1301,11 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ) for i in range(model.config.num_hidden_layers) ) + non_empty_pkv = ( + DynamicCache.from_legacy_cache(empty_pkv) + if model_class._supports_cache_class + else non_empty_pkv + ) inps = copy.deepcopy(inputs_to_test[0]) From 00a98e7125fde87191ae8bb9ce137237fd90637b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 02:18:47 +0100 Subject: [PATCH 63/97] olmo + phi --- src/transformers/models/olmo/modeling_olmo.py | 22 +------------------ src/transformers/models/olmo/modular_olmo.py | 13 ++++++----- src/transformers/models/phi/modeling_phi.py | 12 +++++----- src/transformers/models/phi/modular_phi.py | 12 +++++----- 4 files changed, 20 insertions(+), 39 deletions(-) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 203893ea9a03..af0c73da1f9f 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -47,26 +47,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) -class OlmoRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - OlmoRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class OlmoMLP(nn.Module): def __init__(self, config): super().__init__() @@ -504,7 +484,7 @@ def __init__(self, config: OlmoConfig): self.layers = nn.ModuleList( [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = OlmoRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = OlmoLayerNorm(config.hidden_size) self.rotary_emb = OlmoRotaryEmbedding(config=config) self.gradient_checkpointing = False diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 0ca1135ec31c..e6e030800862 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -13,7 +13,7 @@ LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP, - LlamaRMSNorm, + LlamaModel, apply_rotary_pos_emb, eager_attention_forward, ) @@ -36,11 +36,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype ) - -class OlmoRMSNorm(LlamaRMSNorm): - pass - - class OlmoMLP(LlamaMLP): def __init__(self, config): super().__init__(config) @@ -116,5 +111,11 @@ def __init__(self, config: OlmoConfig, layer_idx: int): self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) +class OlmoModel(LlamaModel): + def __init__(self, config: OlmoConfig): + super().__init__(config) + self.norm = OlmoLayerNorm(config.hidden_size) + + class OlmoForCausalLM(LlamaForCausalLM): pass diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 5881411f7c17..d04635494abd 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -129,18 +129,18 @@ def __init__(self, config: PhiConfig, layer_idx: int): self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) self.k_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) def forward( diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index ee15fed8dfb8..e7ff117159ba 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -25,19 +25,19 @@ class PhiAttention(LlamaAttention): def __init__(self, config: PhiConfig, layer_idx: int): super().__init__(config, layer_idx) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) + self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=True) + self.dense = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) del self.o_proj self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) self.qk_layernorm = config.qk_layernorm if self.qk_layernorm: self.q_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) self.k_layernorm = nn.LayerNorm( - config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True + config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) def forward( From 8c25411205ca141c5f9a48725c49820aad85f125 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 02:47:32 +0100 Subject: [PATCH 64/97] rotary --- src/transformers/models/phi3/modeling_phi3.py | 67 +++++-------------- tests/models/falcon/test_modeling_falcon.py | 25 ++----- .../models/gpt_neox/test_modeling_gpt_neox.py | 25 ++----- .../persimmon/test_modeling_persimmon.py | 26 ++----- tests/models/phi/test_modeling_phi.py | 26 ++----- .../models/stablelm/test_modeling_stablelm.py | 27 ++------ 6 files changed, 48 insertions(+), 148 deletions(-) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index b3b8c6426bce..ed3576f35ff3 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -75,57 +75,27 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +# copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +# TODO cyril: modular class Phi3RotaryEmbedding(nn.Module): - def __init__( - self, - config: Phi3Config, - device=None, - ): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.rope_kwargs = {} - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -133,11 +103,6 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) @@ -375,7 +340,11 @@ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): def _init_rope(self): if self.rope_scaling is None: - self.rotary_emb = Phi3RotaryEmbedding(self.config) + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) else: scaling_type = self.config.rope_scaling["type"] if scaling_type == "longrope": diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index 129bd346a10d..f81c13e89c6b 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -453,6 +453,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -470,11 +471,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = FalconRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -482,13 +479,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -501,13 +493,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = FalconRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = FalconRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ca9fbb225c6d..3af8a2a99202 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -366,7 +366,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->GPTNeoX, rope_theta->rotary_emb_base def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -384,11 +383,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - ).to(torch_device) + original_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -396,13 +391,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -415,13 +405,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = GPTNeoXRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rotary_emb_base, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = GPTNeoXRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index 54ee49b65343..a96a26e51a06 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -417,7 +417,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Persimmon + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -435,11 +435,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = PersimmonRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -447,13 +443,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -466,13 +457,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PersimmonRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = PersimmonRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index df5278cb34e3..7fff4a6f724b 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -396,7 +396,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -414,11 +414,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = PhiRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -426,13 +422,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -445,13 +436,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = PhiRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = PhiRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index bfab01578229..f9e0e6955175 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -402,7 +402,8 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->StableLm + + # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -420,11 +421,7 @@ def test_model_rope_scaling(self): position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE - original_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - ).to(torch_device) + original_rope = StableLmRotaryEmbedding(config).to(torch_device) original_cos_short, original_sin_short = original_rope(x, position_ids_short) original_cos_long, original_sin_long = original_rope(x, position_ids_long) torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) @@ -432,13 +429,8 @@ def test_model_rope_scaling(self): # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" - linear_scaling_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="linear", - ).to(torch_device) + config.rope_scaling = {"type": "linear", "factor": scaling_factor} + linear_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device) linear_cos_short, linear_sin_short = linear_scaling_rope(x, position_ids_short) linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) @@ -451,13 +443,8 @@ def test_model_rope_scaling(self): # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase # with scaling_factor (or that `inv_freq` decreases) - ntk_scaling_rope = StableLmRotaryEmbedding( - head_dim, - max_position_embeddings=config.max_position_embeddings, - base=config.rope_theta, - scaling_factor=scaling_factor, - rope_type="dynamic", - ).to(torch_device) + config.rope_scaling = {"type": "dynamic", "factor": scaling_factor} + ntk_scaling_rope = StableLmRotaryEmbedding(config).to(torch_device) ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) From 4bb2f257f93f25ac48fdb182c98cb37ebe0911f7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 02:49:47 +0100 Subject: [PATCH 65/97] syle --- src/transformers/models/olmo/modular_olmo.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 - tests/models/falcon/test_modeling_falcon.py | 3 --- tests/models/gpt_neox/test_modeling_gpt_neox.py | 3 --- tests/models/persimmon/test_modeling_persimmon.py | 3 --- tests/models/phi/test_modeling_phi.py | 3 --- tests/models/stablelm/test_modeling_stablelm.py | 4 ---- 7 files changed, 1 insertion(+), 17 deletions(-) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index e6e030800862..435ad5e72e61 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -36,6 +36,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: orig_dtype ) + class OlmoMLP(LlamaMLP): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ed3576f35ff3..908fd982b9c7 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -35,7 +35,6 @@ SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, diff --git a/tests/models/falcon/test_modeling_falcon.py b/tests/models/falcon/test_modeling_falcon.py index f81c13e89c6b..3ad46a92bc09 100644 --- a/tests/models/falcon/test_modeling_falcon.py +++ b/tests/models/falcon/test_modeling_falcon.py @@ -456,9 +456,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Falcon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 3af8a2a99202..6d5e081d50b1 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -368,9 +368,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) diff --git a/tests/models/persimmon/test_modeling_persimmon.py b/tests/models/persimmon/test_modeling_persimmon.py index a96a26e51a06..e783cea95a63 100644 --- a/tests/models/persimmon/test_modeling_persimmon.py +++ b/tests/models/persimmon/test_modeling_persimmon.py @@ -420,9 +420,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Persimmon def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index 7fff4a6f724b..c7b59d278e4f 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -399,9 +399,6 @@ def test_model_rope_scaling_from_config(self, scaling_type): # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->Phi def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index f9e0e6955175..c8aa55399035 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -402,13 +402,9 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.gpt_neox.test_modeling_gpt_neox.GPTNeoXModelTest.test_model_rope_scaling with GPTNeoX->StableLm def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - hidden_size = config.hidden_size - num_heads = config.num_attention_heads - head_dim = hidden_size // num_heads scaling_factor = 10 short_input_length = 10 long_input_length = int(config.max_position_embeddings * 1.5) From 44ff5e3d97faadf02a9d619275a4852f43c945c1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 03:08:01 +0100 Subject: [PATCH 66/97] phi --- src/transformers/models/phi/modeling_phi.py | 19 +++++++++---------- src/transformers/models/phi/modular_phi.py | 4 ++-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index d04635494abd..425dc04299dd 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -215,16 +215,15 @@ class PhiMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states class PhiDecoderLayer(nn.Module): diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index e7ff117159ba..eb1765a58a7b 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -12,10 +12,10 @@ LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, - LlamaMLP, apply_rotary_pos_emb, eager_attention_forward, # copied from Llama ) +from ..clip.modeling_clip import CLIPMLP from .configuration_phi import PhiConfig @@ -108,7 +108,7 @@ def forward( return attn_output, attn_weights -class PhiMLP(LlamaMLP): +class PhiMLP(CLIPMLP): pass From 95f7b9637eb792853a7071b2e2dbeac769c99baa Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 03:22:56 +0100 Subject: [PATCH 67/97] phi again --- src/transformers/models/phi/modeling_phi.py | 30 +---- src/transformers/models/phi/modular_phi.py | 128 +++++++++++++++++++- 2 files changed, 129 insertions(+), 29 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 425dc04299dd..85280dffc4e4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -9,10 +9,8 @@ import torch import torch.nn as nn -from transformers.cache_utils import Cache - from ...activations import ACT2FN -from ...cache_utils import DynamicCache, StaticCache +from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -276,26 +274,6 @@ def forward( return outputs -class PhiRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - PhiRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - class PhiRotaryEmbedding(nn.Module): def __init__( self, @@ -502,9 +480,10 @@ def __init__(self, config: PhiConfig): self.layers = nn.ModuleList( [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) - self.norm = PhiRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = PhiRotaryEmbedding(config=config) self.gradient_checkpointing = False + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # Initialize weights and apply final processing self.post_init() @@ -565,6 +544,7 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) + inputs_embeds = self.embed_dropout(inputs_embeds) hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -608,7 +588,7 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.norm(hidden_states) + hidden_states = self.final_layernorm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index eb1765a58a7b..52b2c9ba00ca 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -1,21 +1,26 @@ -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn -from transformers.cache_utils import Cache - +from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPast, +) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...utils import logging +from ..clip.modeling_clip import CLIPMLP from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, + LlamaModel, apply_rotary_pos_emb, eager_attention_forward, # copied from Llama ) -from ..clip.modeling_clip import CLIPMLP from .configuration_phi import PhiConfig @@ -162,6 +167,121 @@ def forward( return outputs +class PhiModel(LlamaModel): + def __init__(self, config: PhiConfig): + super().__init__(config) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + del self.norm + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + inputs_embeds = self.embed_dropout(inputs_embeds) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + class PhiForCausalLM(LlamaForCausalLM): pass From 7d550361a0250244b67e03dc4b54c73bab267e18 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 03:31:05 +0100 Subject: [PATCH 68/97] again --- src/transformers/models/phi/modeling_phi.py | 5 +---- src/transformers/models/phi/modular_phi.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 85280dffc4e4..790e16a935b6 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -249,7 +249,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( + attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -268,9 +268,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 52b2c9ba00ca..fb6f1dd7b599 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -142,7 +142,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( + attn_outputs, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -161,9 +161,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs From 24ac9ab8dc6d1bc6e378afb779c260af62cf2a47 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 03:32:23 +0100 Subject: [PATCH 69/97] kwargs --- src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi/modular_phi.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 790e16a935b6..a86ee6fda19f 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -258,6 +258,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index fb6f1dd7b599..bcaa2da1b486 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -151,6 +151,7 @@ def forward( use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, + **kwargs, ) attn_outputs = self.resid_dropout(attn_outputs) From bd8ede8a0ce539eb03bf8a3c99dcd716fd427028 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 03:37:06 +0100 Subject: [PATCH 70/97] Update test_modeling_common.py --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 23c5f33b0802..e2912ff1ebf4 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1302,7 +1302,7 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa for i in range(model.config.num_hidden_layers) ) non_empty_pkv = ( - DynamicCache.from_legacy_cache(empty_pkv) + DynamicCache.from_legacy_cache(non_empty_pkv) if model_class._supports_cache_class else non_empty_pkv ) From 0d3d3e392d85b80635698f8753c29d97cecdad4a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 10:58:46 +0100 Subject: [PATCH 71/97] skip fx tracing tests --- src/transformers/models/aria/modeling_aria.py | 1 - src/transformers/models/gemma/modeling_gemma.py | 1 - src/transformers/models/glm/modeling_glm.py | 1 - src/transformers/models/llama/modeling_llama.py | 1 - src/transformers/models/mistral/modeling_mistral.py | 1 - src/transformers/models/olmo/modeling_olmo.py | 1 - src/transformers/models/qwen2/modeling_qwen2.py | 1 - src/transformers/models/starcoder2/modeling_starcoder2.py | 1 - tests/models/gpt2/test_modeling_gpt2.py | 2 +- tests/models/llama/test_modeling_llama.py | 6 +----- tests/models/mistral/test_modeling_mistral.py | 2 +- tests/models/mixtral/test_modeling_mixtral.py | 2 +- tests/models/qwen2/test_modeling_qwen2.py | 2 +- tests/models/qwen2_moe/test_modeling_qwen2_moe.py | 2 +- 14 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index c800f0627092..0af8d2888a9d 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -640,7 +640,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 043af2d85ba6..df48af4e007f 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -355,7 +355,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 343ad4a4b035..caab4ce0a778 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -371,7 +371,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 68ecadc08ba0..de67315f5ff8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -359,7 +359,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 2f467dc665b3..a2c9a5f635f1 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -333,7 +333,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index af0c73da1f9f..09d138f7056b 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -271,7 +271,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6f43e57d119f..2815df8f8f59 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -281,7 +281,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index c85a41a9c68b..993748be78f5 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -268,7 +268,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 012444b472c0..88ccdc8ee45a 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -507,7 +507,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin else {} ) all_parallelizable_model_classes = (GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else () - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez test_missing_keys = False test_model_parallel = True diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0790de4e133b..78e42e6ba71f 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -308,7 +308,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # Need to use `0.8` instead of `0.9` for `test_cpu_offload` # This is because we are hitting edge cases with the causal_mask buffer @@ -571,10 +571,6 @@ def test_use_flash_attention_2_true(self): if not has_flash: raise ValueError("The flash model should have flash attention layers") - @unittest.skip("Broken by the loss update will fix soon @ArthurZucker") - def test_torch_fx_output_loss(self, *args, **kwargs): - pass - @require_torch_gpu class LlamaIntegrationTest(unittest.TestCase): diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index c5ea050edf92..d9e6b9d7bfe7 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -316,7 +316,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index 931bb1f17bec..9abbf444d0b0 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -314,7 +314,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 6c32a66e0362..ecfa9189d12e 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -327,7 +327,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index abc7b57919b0..21d11047ff1b 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -352,7 +352,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ) test_headmasking = False test_pruning = False - fx_compatible = True + fx_compatible = False # Broken by attention refactor cc @Cyrilvallez # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 def is_pipeline_test_to_skip( From 49135d0479685a727d04524deed22e71a81455fe Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:09:28 +0100 Subject: [PATCH 72/97] Update modeling_utils.py --- src/transformers/modeling_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ee576f69cfe4..ed1c497553fb 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1545,8 +1545,6 @@ def _autoset_attn_implementation( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." ) torch.backends.cuda.enable_flash_sdp(False) - elif config._attn_implementation in ALL_ATTENTION_FUNCTIONS: - pass elif isinstance(requested_attn_implementation, dict): config._attn_implementation = None else: From f80a2c33e96f1ef22e25da3832ba41c2fe7d4cc9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:24:40 +0100 Subject: [PATCH 73/97] gemma 2 --- .../modular-transformers/modeling_dummy.py | 1 - .../modeling_multimodal1.py | 1 - .../modeling_my_new_model2.py | 1 - .../modular-transformers/modeling_super.py | 1 - .../models/gemma2/modeling_gemma2.py | 16 +++++++----- .../models/gemma2/modular_gemma2.py | 25 ++++++++----------- 6 files changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 3dedf14f46a6..b33ce10ae9e4 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -324,7 +324,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index c3844b402838..2ce2036e5645 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -324,7 +324,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index e07b31ea89cd..21c55efc4dbc 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -324,7 +324,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index afc45e92f490..b0a863efd712 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -324,7 +324,6 @@ def forward( hidden_states = residual + hidden_states outputs = (hidden_states,) - if output_attentions: outputs += (self_attn_weights,) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 44cc4c31bb07..771617b2a750 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -268,6 +268,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -296,6 +297,7 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -601,6 +603,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -619,6 +624,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -629,6 +635,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -647,16 +654,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 5dca7209f12e..6d9dd1be5c50 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -314,6 +314,7 @@ def __init__(self, config: Gemma2Config, layer_idx: int): def forward( self, hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -342,6 +343,7 @@ def forward( # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, + position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, @@ -373,14 +375,7 @@ class Gemma2PreTrainedModel(GemmaPreTrainedModel): pass -class Gemma2Model(GemmaModel, Gemma2PreTrainedModel): - def __init__(self, config: Gemma2Config): - super().__init__(config) - self.layers = nn.ModuleList( - [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.post_init() - +class Gemma2Model(GemmaModel): def forward( self, input_ids: torch.LongTensor = None, @@ -439,6 +434,9 @@ def forward( # embed positions hidden_states = inputs_embeds + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + # normalized # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 # See https://github.com/huggingface/transformers/pull/29402 @@ -457,6 +455,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, + position_embeddings, causal_mask, position_ids, past_key_values, @@ -467,6 +466,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, + position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, @@ -485,16 +485,13 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = past_key_values if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=next_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns, ) + return output if return_dict else output.to_tuple() @torch.no_grad() def _update_causal_mask( From 3e461bd19d96ee52d58b2e1d703bed251176e208 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:30:42 +0100 Subject: [PATCH 74/97] again --- .../models/gemma2/modeling_gemma2.py | 93 ++++++++++--------- .../models/gemma2/modular_gemma2.py | 19 +--- 2 files changed, 50 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 771617b2a750..83b0af97a7d2 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -86,7 +86,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_activation] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj def rotate_half(x): @@ -325,51 +326,6 @@ def forward( return outputs -GEMMA2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Gemma2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", - GEMMA2_START_DOCSTRING, -) -class Gemma2PreTrainedModel(PreTrainedModel): - config_class = Gemma2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Gemma2DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class Gemma2RotaryEmbedding(nn.Module): def __init__( self, @@ -435,6 +391,51 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +GEMMA2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Gemma2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.", + GEMMA2_START_DOCSTRING, +) +class Gemma2PreTrainedModel(PreTrainedModel): + config_class = Gemma2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Gemma2DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + GEMMA2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 6d9dd1be5c50..2965ea649edd 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -35,8 +35,8 @@ GemmaForCausalLM, GemmaForSequenceClassification, GemmaForTokenClassification, + GemmaMLP, GemmaModel, - GemmaPreTrainedModel, GemmaRMSNorm, apply_rotary_pos_emb, repeat_kv, @@ -185,20 +185,11 @@ class Gemma2RMSNorm(GemmaRMSNorm): pass -class Gemma2MLP(nn.Module): +class Gemma2MLP(GemmaMLP): def __init__(self, config): super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_activation] - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - def eager_attention_forward( module: nn.Module, @@ -371,10 +362,6 @@ def forward( return outputs -class Gemma2PreTrainedModel(GemmaPreTrainedModel): - pass - - class Gemma2Model(GemmaModel): def forward( self, @@ -712,7 +699,7 @@ def __init__(self, config): "Gemma2Config", "Gemma2ForCausalLM", "Gemma2Model", - "Gemma2PreTrainedModel", + "Gemma2PreTrainedModel", # noqa: F822 "Gemma2ForSequenceClassification", "Gemma2ForTokenClassification", ] From 7a882d559a52a00022275d703bef119cb8733564 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:33:22 +0100 Subject: [PATCH 75/97] Update modeling_recurrent_gemma.py --- .../recurrent_gemma/modeling_recurrent_gemma.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 5ed4049e09ae..74fc2085c365 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -77,15 +77,13 @@ def __init__(self, dim, base=10000, device=None): self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) @torch.no_grad() - # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->RecurrentGemma - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 device_type = x.device.type device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): @@ -93,11 +91,6 @@ def forward(self, x, position_ids): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) From 78700734bc0ed3b0991732933da8175a4cc8d4f8 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:37:27 +0100 Subject: [PATCH 76/97] gemma2 --- src/transformers/models/gemma2/modeling_gemma2.py | 5 +---- src/transformers/models/gemma2/modular_gemma2.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 83b0af97a7d2..67fc6c86a3ba 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -296,7 +296,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -320,9 +320,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 2965ea649edd..f9ec305b2272 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -332,7 +332,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, @@ -356,9 +356,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs From 5b4ebaade5f1f1b3d301d87fe425bc2fc92afffc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:53:48 +0100 Subject: [PATCH 77/97] granite --- .../models/granite/modeling_granite.py | 78 +++++++++---------- .../models/granite/modular_granite.py | 26 +------ tests/test_modeling_common.py | 2 +- 3 files changed, 41 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 8216caa8185b..995fe8e6c916 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -47,42 +47,6 @@ _CONFIG_FOR_DOC = "GraniteConfig" -class GraniteRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - GraniteRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -class GraniteMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -234,6 +198,42 @@ def forward( return attn_output, attn_weights +class GraniteRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + GraniteRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class GraniteMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + class GraniteDecoderLayer(nn.Module): def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() @@ -244,7 +244,6 @@ def __init__(self, config: GraniteConfig, layer_idx: int): self.mlp = GraniteMLP(config) self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.residual_multiplier = config.residual_multiplier def forward( @@ -286,7 +285,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -310,9 +309,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index e276fa5ef1f1..9449f7aeea66 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -23,21 +23,12 @@ from ...utils import ( logging, ) -from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaMLP, LlamaRMSNorm +from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaDecoderLayer from .configuration_granite import GraniteConfig logger = logging.get_logger(__name__) - -class GraniteRMSNorm(LlamaRMSNorm): - pass - - -class GraniteMLP(LlamaMLP): - pass - - class GraniteAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -46,17 +37,9 @@ def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): self.scaling = config.attention_multiplier -class GraniteDecoderLayer(nn.Module): +class GraniteDecoderLayer(LlamaDecoderLayer): def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) - - self.mlp = GraniteMLP(config) - self.input_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.residual_multiplier = config.residual_multiplier def forward( @@ -98,7 +81,7 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, @@ -122,9 +105,6 @@ def forward( if output_attentions: outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) - return outputs diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e2912ff1ebf4..fe4005b3f737 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2713,7 +2713,7 @@ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """ Args: model_class: The class of the model that is currently testing. For example, ..., etc. From 7bdf61c6f779bca27eed60c4b5b4e6cbfbf83f5e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 11:54:36 +0100 Subject: [PATCH 78/97] style --- src/transformers/models/granite/modular_granite.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 9449f7aeea66..61479c81bd23 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -17,18 +17,18 @@ import torch import torch.utils.checkpoint -from torch import nn from ...cache_utils import Cache from ...utils import ( logging, ) -from ..llama.modeling_llama import LlamaAttention, LlamaForCausalLM, LlamaDecoderLayer +from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM from .configuration_granite import GraniteConfig logger = logging.get_logger(__name__) + class GraniteAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" From 7d5b0b53ce5072278f9f68a6ac01c41b3c691d3a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 12:21:04 +0100 Subject: [PATCH 79/97] starcoder --- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 ++ src/transformers/models/starcoder2/modular_starcoder2.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 993748be78f5..be09aade9de7 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -211,9 +211,11 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), **kwargs, ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index f5124e07e5db..8fc041dd422a 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -129,9 +129,11 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, + sliding_window=getattr(self.config, "sliding_window", None), **kwargs, ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) From 70ef2fd935928c24c7ca6b95b29bd442678234c2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 14:16:48 +0100 Subject: [PATCH 80/97] Update sdpa_attention.py --- src/transformers/integrations/sdpa_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 9ae94ccf0f44..7baa0da1b7f3 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -37,7 +37,7 @@ def sdpa_attention_forward( key = key.contiguous() value = value.contiguous() - is_causal = True if causal_mask is None and query.shape[1] > 1 else False + is_causal = causal_mask is None and query.shape[2] > 1 attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, From b8429c5efa2aa0a5e113af1708e1a4d138edb7bd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 14:43:53 +0100 Subject: [PATCH 81/97] switch args --- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi/modular_phi.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a86ee6fda19f..c4786a980b5a 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -237,9 +237,9 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index bcaa2da1b486..05676a0deb54 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -130,9 +130,9 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, From 533657c9a8d4441d7dc7f5225440767aabdc3d09 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 14:46:20 +0100 Subject: [PATCH 82/97] Update modeling_mllama.py --- src/transformers/models/mllama/modeling_mllama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d53a80dd8929..3e0c4d7a5123 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -829,7 +829,8 @@ def __init__(self, config): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer From fe20d63afeb591b4d431b896ad3e8290e5386b75 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 15:06:59 +0100 Subject: [PATCH 83/97] fix --- .../models/granite/modeling_granite.py | 4 + .../models/granite/modular_granite.py | 188 +++++++++++++++++- utils/check_config_attributes.py | 3 + 3 files changed, 188 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 995fe8e6c916..5c2f3d85ecb9 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -521,6 +521,7 @@ def __init__(self, config: GraniteConfig): self.norm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = GraniteRotaryEmbedding(config=config) self.gradient_checkpointing = False + self.embedding_multiplier = config.embedding_multiplier # Initialize weights and apply final processing self.post_init() @@ -565,6 +566,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds * self.embedding_multiplier + if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -866,6 +869,7 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits / self.config.logits_scaling loss = None if labels is not None: diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 61479c81bd23..8f683e9a46ec 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -13,16 +13,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint -from ...cache_utils import Cache -from ...utils import ( - logging, -) -from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM +from ...cache_utils import Cache, DynamicCache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...processing_utils import Unpack +from ...utils import LossKwargs, logging +from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from .configuration_granite import GraniteConfig @@ -108,5 +109,178 @@ def forward( return outputs +class GraniteModel(LlamaModel): + def __init__(self, config: GraniteConfig): + super().__init__(config) + self.embedding_multiplier = config.embedding_multiplier + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embedding_multiplier + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + class GraniteForCausalLM(LlamaForCausalLM): - pass + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + logits = logits / self.config.logits_scaling + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index a125387ff292..7b2551356c97 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -307,6 +307,9 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "backbone_config", "use_timm_backbone", "backbone_kwargs", + "pretraining_tp", + "rope_theta", + "partial_rotary_factor", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] From 248a60725fd62330c04ed5260f113157632fed25 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 15:55:48 +0100 Subject: [PATCH 84/97] cache type tests --- tests/test_modeling_common.py | 7 +++++-- tests/test_modeling_flax_common.py | 6 ++++-- tests/test_modeling_tf_common.py | 5 ++++- utils/check_config_attributes.py | 3 ++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fe4005b3f737..24e035279b4e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2482,7 +2482,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla return new_tf_outputs, new_pt_outputs # Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs - def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way. Args: @@ -2538,6 +2538,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))]) for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(tf_outputs, tf.Tensor): @@ -2723,7 +2725,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n Currently unused, but in the future, we could use this information to make the error message clearer by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. """ - self.assertEqual(type(name), str) if attributes is not None: self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") @@ -2768,6 +2769,8 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, n attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(fx_outputs, jnp.ndarray): diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index c7d098be3ea8..ae4cccafbcb8 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -22,6 +22,7 @@ import numpy as np import transformers +from transformers.cache_utils import DynamicCache from transformers import is_flax_available, is_torch_available from transformers.models.auto import get_values from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device @@ -180,7 +181,7 @@ def recursive_check(tuple_object, dict_object): check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) # (Copied from tests.test_modeling_common.ModelTesterMixin.check_pt_flax_outputs) - def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """ Args: model_class: The class of the model that is currently testing. For example, ..., etc. @@ -190,7 +191,6 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n Currently unused, but in the future, we could use this information to make the error message clearer by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. """ - self.assertEqual(type(name), str) if attributes is not None: self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`") @@ -235,6 +235,8 @@ def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, n attributes = tuple([f"{name}_{idx}" for idx in range(len(fx_outputs))]) for fx_output, pt_output, attr in zip(fx_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_flax_outputs(fx_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(fx_outputs, jnp.ndarray): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index eb328d83e9e7..9dc712ab67b6 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -484,7 +484,7 @@ def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_cla return new_tf_outputs, new_pt_outputs - def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None): + def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-4, name="outputs", attributes=None): """Check the outputs from PyTorch and TensorFlow models are close enough. Checks are done in a recursive way. Args: @@ -495,6 +495,7 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element being a named field in the output. """ + from transformers.cache_utils import DynamicCache self.assertEqual(type(name), str) if attributes is not None: @@ -540,6 +541,8 @@ def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, nam attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))]) for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes): + if isinstance(pt_output, DynamicCache): + pt_output = pt_output.to_legacy_cache() self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr) elif isinstance(tf_outputs, tf.Tensor): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 7b2551356c97..4b8fd212d2a9 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -41,6 +41,7 @@ "expert_layer_offset", "expert_layer_period", ], + "LlamaConfig": ["pretraining_tp"], "Qwen2Config": ["use_sliding_window"], "Qwen2MoeConfig": ["use_sliding_window"], "Qwen2VLConfig": ["use_sliding_window"], @@ -307,7 +308,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s "backbone_config", "use_timm_backbone", "backbone_kwargs", - "pretraining_tp", + # rope attributes may not appear directly in the modeling but are used "rope_theta", "partial_rotary_factor", ] From 4646014224e9cc32c5cbea8fa4b4eb3a8c92f023 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 18:01:52 +0100 Subject: [PATCH 85/97] gpt2 --- .../modeling_decision_transformer.py | 9 ++++----- src/transformers/models/gpt2/modeling_gpt2.py | 9 ++++----- tests/test_modeling_flax_common.py | 2 +- utils/check_config_attributes.py | 2 +- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 42f5cd332adc..92fa06d2c941 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -138,6 +138,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights @@ -252,6 +253,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights @@ -325,15 +327,12 @@ def forward( **kwargs, ) - attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.c_proj(attn_output_reshaped) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present) if output_attentions: - # weird but needed to satisfy BC and tests (would normally be None) - if self.config._attn_implementation == "flash_attention_2": - attn_weights = attn_output_reshaped outputs += (attn_weights,) return outputs # a, present, (attentions) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b824fef9a42a..fbf1367698ac 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -154,6 +154,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights @@ -267,6 +268,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea attn_weights = attn_weights * head_mask attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2) return attn_output, attn_weights @@ -340,15 +342,12 @@ def forward( **kwargs, ) - attn_output_reshaped = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.c_proj(attn_output_reshaped) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) outputs = (attn_output, present) if output_attentions: - # weird but needed to satisfy BC and tests (would normally be None) - if self.config._attn_implementation == "flash_attention_2": - attn_weights = attn_output_reshaped outputs += (attn_weights,) return outputs # a, present, (attentions) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index ae4cccafbcb8..bfe1648de049 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -22,8 +22,8 @@ import numpy as np import transformers -from transformers.cache_utils import DynamicCache from transformers import is_flax_available, is_torch_available +from transformers.cache_utils import DynamicCache from transformers.models.auto import get_values from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index 4b8fd212d2a9..420d6e6a2475 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -41,7 +41,6 @@ "expert_layer_offset", "expert_layer_period", ], - "LlamaConfig": ["pretraining_tp"], "Qwen2Config": ["use_sliding_window"], "Qwen2MoeConfig": ["use_sliding_window"], "Qwen2VLConfig": ["use_sliding_window"], @@ -311,6 +310,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s # rope attributes may not appear directly in the modeling but are used "rope_theta", "partial_rotary_factor", + "pretraining_tp", ] attributes_used_in_generation = ["encoder_no_repeat_ngram_size"] From ad16b1bd0ae0f536030d29e008313ba4a3974e94 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 18:25:52 +0100 Subject: [PATCH 86/97] Update test_modeling_common.py --- tests/test_modeling_common.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 24e035279b4e..19ae0adec6dc 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -43,7 +43,6 @@ logging, set_seed, ) -from transformers.cache_utils import DynamicCache from transformers.integrations import HfDeepSpeedConfig from transformers.integrations.deepspeed import ( is_deepspeed_available, @@ -122,6 +121,7 @@ from transformers import MODEL_MAPPING, AdaptiveEmbedding from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage + from transformers.cache_utils import DynamicCache if is_tf_available(): @@ -3890,6 +3890,11 @@ def test_sdpa_can_dispatch_non_composite_models(self): model_eager = model_eager.eval().to(torch_device) self.assertTrue(model_eager.config._attn_implementation == "eager") + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + @require_torch_sdpa def test_sdpa_can_dispatch_composite_models(self): """ @@ -3942,15 +3947,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]): - raise ValueError("The SDPA model should have SDPA attention layers") - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa def test_eager_matches_sdpa_inference(self, torch_dtype: str): From 1df6e29bbf91b6a293c86f6c3116a8d01411b8d6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 18:45:03 +0100 Subject: [PATCH 87/97] fix --- src/transformers/models/gpt2/modeling_gpt2.py | 6 +++--- .../encoder_decoder/test_modeling_encoder_decoder.py | 9 --------- tests/models/qwen2_audio/test_modeling_qwen2_audio.py | 9 --------- .../test_modeling_speech_encoder_decoder.py | 9 --------- .../test_modeling_vision_encoder_decoder.py | 9 --------- tests/test_modeling_common.py | 2 +- 6 files changed, 4 insertions(+), 40 deletions(-) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index fbf1367698ac..ad91eee9d9dc 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -300,9 +300,9 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(hidden_shape).transpose(1, 2) - key_states = key_states.view(hidden_shape).transpose(1, 2) - value_states = value_states.view(hidden_shape).transpose(1, 2) + query_states = query_states.reshape(hidden_shape).transpose(1, 2) + key_states = key_states.reshape(hidden_shape).transpose(1, 2) + value_states = value_states.reshape(hidden_shape).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 64ebedcb4598..1c4051f2e264 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -733,15 +733,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 42b521e518e2..4806ec2c72d3 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -206,15 +206,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 7dcb7c406ae2..897d4b056f19 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -500,15 +500,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index 77e2a19fea48..2b517034bffb 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -441,15 +441,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 19ae0adec6dc..4cf62d23c1cf 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -119,9 +119,9 @@ from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding + from transformers.cache_utils import DynamicCache from transformers.modeling_utils import load_state_dict, no_init_weights from transformers.pytorch_utils import id_tensor_storage - from transformers.cache_utils import DynamicCache if is_tf_available(): From 6c01005c332b3c9e15d93a49e1e3d215f614dfb4 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 18:58:14 +0100 Subject: [PATCH 88/97] consistency --- .../modeling_decision_transformer.py | 6 +++--- tests/models/idefics2/test_modeling_idefics2.py | 9 --------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 92fa06d2c941..f399ff447920 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -285,9 +285,9 @@ def forward( input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(hidden_shape).transpose(1, 2) - key_states = key_states.view(hidden_shape).transpose(1, 2) - value_states = value_states.view(hidden_shape).transpose(1, 2) + query_states = query_states.reshape(hidden_shape).transpose(1, 2) + key_states = key_states.reshape(hidden_shape).transpose(1, 2) + value_states = value_states.reshape(hidden_shape).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index ae8c91f29d4d..83e125c07c15 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -362,15 +362,6 @@ def test_sdpa_can_dispatch_composite_models(self): if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: raise ValueError("The eager model should not have SDPA attention layers") - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - @require_torch class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase): From f651cd0d40088dfc64b6a13ae119b4f9d86f78b9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 19:15:38 +0100 Subject: [PATCH 89/97] fix shape with encoder --- .../modeling_decision_transformer.py | 12 ++++++------ src/transformers/models/gpt2/modeling_gpt2.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index f399ff447920..113de8a0aabc 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -282,12 +282,12 @@ def forward( else: query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + shape_q = (*query_states.shape[:-1], -1, self.head_dim) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query_states = query_states.reshape(hidden_shape).transpose(1, 2) - key_states = key_states.reshape(hidden_shape).transpose(1, 2) - value_states = value_states.reshape(hidden_shape).transpose(1, 2) + query_states = query_states.reshape(shape_q).transpose(1, 2) + key_states = key_states.reshape(shape_kv).transpose(1, 2) + value_states = value_states.reshape(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past @@ -327,7 +327,7 @@ def forward( **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index ad91eee9d9dc..c40836120964 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -297,12 +297,12 @@ def forward( else: query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + shape_q = (*query_states.shape[:-1], -1, self.head_dim) + shape_kv = (*key_states.shape[:-1], -1, self.head_dim) - query_states = query_states.reshape(hidden_shape).transpose(1, 2) - key_states = key_states.reshape(hidden_shape).transpose(1, 2) - value_states = value_states.reshape(hidden_shape).transpose(1, 2) + query_states = query_states.reshape(shape_q).transpose(1, 2) + key_states = key_states.reshape(shape_kv).transpose(1, 2) + value_states = value_states.reshape(shape_kv).transpose(1, 2) if layer_past is not None: past_key, past_value = layer_past @@ -342,7 +342,7 @@ def forward( **kwargs, ) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) From 98b7f974392a19475d5d21dbe0535438dd5fbbfc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 19:53:53 +0100 Subject: [PATCH 90/97] should be the last one --- src/transformers/integrations/sdpa_attention.py | 5 ++++- .../decision_transformer/modeling_decision_transformer.py | 4 ++++ src/transformers/models/gpt2/modeling_gpt2.py | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 7baa0da1b7f3..265260c9b79e 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -23,6 +23,7 @@ def sdpa_attention_forward( attention_mask: Optional[torch.Tensor], dropout: float = 0.0, scaling: Optional[float] = None, + is_causal: Optional[bool] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: if hasattr(module, "num_key_value_groups"): @@ -37,7 +38,9 @@ def sdpa_attention_forward( key = key.contiguous() value = value.contiguous() - is_causal = causal_mask is None and query.shape[2] > 1 + if is_causal is None: + is_causal = causal_mask is None and query.shape[2] > 1 + attn_output = torch.nn.functional.scaled_dot_product_attention( query, key, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 113de8a0aabc..c8506445d399 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -299,6 +299,9 @@ def forward( else: present = None + is_cross_attention = encoder_hidden_states is not None + is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention + using_eager = self.config._attn_implementation == "eager" attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -324,6 +327,7 @@ def forward( attention_mask, head_mask=head_mask, dropout=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, **kwargs, ) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index c40836120964..1495b97d0e44 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -314,6 +314,9 @@ def forward( else: present = None + is_cross_attention = encoder_hidden_states is not None + is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention + using_eager = self.config._attn_implementation == "eager" attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -339,6 +342,7 @@ def forward( attention_mask, head_mask=head_mask, dropout=self.attn_dropout.p if self.training else 0.0, + is_causal=is_causal, **kwargs, ) From 88e2fe56bb3bc6b21ba971e8f346a6df6642921b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 17 Dec 2024 20:14:23 +0100 Subject: [PATCH 91/97] tests non model --- src/transformers/modeling_utils.py | 7 +----- tests/utils/test_modeling_utils.py | 35 ------------------------------ 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ed1c497553fb..ad9d251b7a92 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1488,12 +1488,7 @@ def _autoset_attn_implementation( message += ( ', `"attn_implementation=flex_attention"` (implementation using torch\'s flex_attention)' ) - if cls.model_type + "_" + config._attn_implementation in ALL_ATTENTION_FUNCTIONS: - config._attn_implementation_internal = cls.model_type + "_" + config._attn_implementation - if config._attn_implementation in ALL_ATTENTION_FUNCTIONS: - pass - else: - raise ValueError(message + ".") + raise ValueError(message + ".") # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. requested_attn_implementation = config._attn_implementation_internal diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 458ddeee5ff8..5a7119566c07 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -563,32 +563,17 @@ def test_model_from_pretrained_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") - mistral_attention_classes = { - "eager": "MistralAttention", - "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", - } for requested_attn_implementation in attn_implementation_available: model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, attn_implementation=requested_attn_implementation ) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) config = AutoConfig.from_pretrained(TINY_MISTRAL) model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, config=config, attn_implementation=requested_attn_implementation ) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) def test_model_from_config_attn_implementation(self): # test that the model can be instantiated with attn_implementation of either @@ -602,11 +587,6 @@ def test_model_from_config_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") - mistral_attention_classes = { - "eager": "MistralAttention", - "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", - } for requested_attn_implementation in attn_implementation_available: config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) # Ensure the config was set correctly @@ -614,11 +594,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, requested_attn_implementation) model = AutoModelForCausalLM.from_config(config) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) config = AutoConfig.from_pretrained(TINY_MISTRAL) # When the config is not set, the default is "eager" @@ -626,11 +601,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, None) model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) # Set a nonsense attn_implementation in the config, which should be overridden by the explicit argument config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation="foo-bar-baz") @@ -638,11 +608,6 @@ def test_model_from_config_attn_implementation(self): self.assertEqual(config._attn_implementation_internal, "foo-bar-baz") model = AutoModelForCausalLM.from_config(config=config, attn_implementation=requested_attn_implementation) self.assertEqual(model.config._attn_implementation, requested_attn_implementation) - for module in model.modules(): - if "Attention" in module.__class__.__name__: - self.assertEqual( - module.__class__.__name__, mistral_attention_classes[requested_attn_implementation] - ) def test_torch_dtype_byte_sizes(self): torch_dtypes_and_bytes = [ From 5a3bdc44ce395015b2358edd7bb6bb0552c5b317 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 13:40:49 +0100 Subject: [PATCH 92/97] most comments --- .../modular-transformers/modeling_dummy.py | 5 +- .../modeling_multimodal1.py | 5 +- .../modeling_my_new_model2.py | 5 +- .../modular-transformers/modeling_super.py | 5 +- .../integrations/flash_attention.py | 10 ++- src/transformers/models/aria/modeling_aria.py | 5 +- .../modeling_decision_transformer.py | 3 + .../models/gemma/modeling_gemma.py | 13 ++-- .../models/gemma/modular_gemma.py | 33 ++++++++- src/transformers/models/glm/modeling_glm.py | 5 +- src/transformers/models/gpt2/modeling_gpt2.py | 8 +-- .../models/granite/modeling_granite.py | 11 ++- .../models/granite/modular_granite.py | 6 +- .../models/llama/modeling_llama.py | 5 +- .../models/mistral/modeling_mistral.py | 13 ++-- .../models/mistral/modular_mistral.py | 72 ++++++++++++++++++- .../models/mixtral/modeling_mixtral.py | 13 ++-- src/transformers/models/olmo/modeling_olmo.py | 5 +- .../models/olmo2/modeling_olmo2.py | 5 +- .../models/olmo2/modular_olmo2.py | 5 +- src/transformers/models/phi/modeling_phi.py | 9 +-- src/transformers/models/phi/modular_phi.py | 4 +- .../models/qwen2/modeling_qwen2.py | 7 +- .../models/qwen2/modular_qwen2.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 6 +- .../models/starcoder2/modeling_starcoder2.py | 15 ++-- .../models/starcoder2/modular_starcoder2.py | 20 +++--- 27 files changed, 170 insertions(+), 125 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b33ce10ae9e4..3e0aa6e9b2ad 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -177,13 +177,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 2ce2036e5645..c4f90a5cbada 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -177,13 +177,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 21c55efc4dbc..b8d5b5eb9100 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -177,13 +177,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index b0a863efd712..42d8108ee72a 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -177,13 +177,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index 2e48902a4ba6..1be223f8b079 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -19,7 +19,6 @@ def flash_attention_forward( scaling: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - target_dtype: torch.dtype = torch.float16, **kwargs, ) -> Tuple[torch.Tensor, None]: # This is before the transpose @@ -30,11 +29,10 @@ def flash_attention_forward( key = key.transpose(1, 2) value = value.transpose(1, 2) - input_dtype = query.dtype - if input_dtype == torch.float32: - query = query.to(target_dtype) - key = key.to(target_dtype) - value = value.to(target_dtype) + if query.dtype == torch.float32: + query = query.to(torch.float16) + key = key.to(torch.float16) + value = value.to(torch.float16) attn_output = _flash_attention_forward( query, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0af8d2888a9d..6481d6f3c434 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -482,13 +482,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index c8506445d399..60fea55d87be 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -312,6 +312,9 @@ def forward( 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: + # Attention functions are consistent with previous equivalent attention classes, however they do not support some options + # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but + # not necessarily to eager (if mentionned options are provided). attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if using_eager and self.reorder_and_upcast_attn: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index df48af4e007f..e2ea12b03fe4 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -208,13 +208,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -820,16 +817,16 @@ def forward( ```python >>> from transformers import AutoTokenizer, GemmaForCausalLM - >>> model = GemmaForCausalLM.from_pretrained("meta-gemma/Gemma-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-gemma/Gemma-2-7b-hf") + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> prompt = "What is your favorite condiment?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + "What is your favorite condiment?" ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 2c4c26374204..29b6f8a19461 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -466,7 +466,38 @@ def forward( class GemmaForCausalLM(LlamaForCausalLM): - pass + def forward(**super_kwargs): + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" + ```""" + return super().forward(**super_kwargs) class GemmaForSequenceClassification(LlamaForSequenceClassification): diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index caab4ce0a778..95ad0d971995 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -92,13 +92,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 1495b97d0e44..ad53c7804ebe 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -43,7 +43,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, replace_return_docstrings, ) @@ -51,10 +50,6 @@ from .configuration_gpt2 import GPT2Config -if is_flash_attn_2_available(): - pass - - logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "openai-community/gpt2" @@ -327,6 +322,9 @@ def forward( 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' ) else: + # Attention functions are consistent with previous equivalent attention classes, however they do not support some options + # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but + # not necessarily to eager (if mentionned options are provided). attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] if using_eager and self.reorder_and_upcast_attn: diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 5c2f3d85ecb9..53357dbb2b56 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -99,13 +99,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -302,7 +299,7 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama outputs = (hidden_states,) @@ -566,7 +563,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds * self.embedding_multiplier + inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -869,7 +866,7 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - logits = logits / self.config.logits_scaling + logits = logits / self.config.logits_scaling # main diff with Llama loss = None if labels is not None: diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 8f683e9a46ec..e8b42e96fc33 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -99,7 +99,7 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama outputs = (hidden_states,) @@ -147,7 +147,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds * self.embedding_multiplier + inputs_embeds = inputs_embeds * self.embedding_multiplier # main diff with Llama if use_cache and past_key_values is None: past_key_values = DynamicCache() @@ -267,7 +267,7 @@ def forward( hidden_states = outputs[0] # Only compute necessary logits, and do not upcast them to float if we are not computing the loss logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - logits = logits / self.config.logits_scaling + logits = logits / self.config.logits_scaling # main diff with Llama loss = None if labels is not None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index de67315f5ff8..5be33c26414c 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -212,13 +212,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a2c9a5f635f1..fdad30544b00 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -109,13 +109,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -191,7 +188,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama **kwargs, ) @@ -1075,12 +1072,12 @@ def forward( MISTRAL_START_DOCSTRING, ) class MistralForQuestionAnswering(MistralPreTrainedModel): - base_model_prefix = "transformer" + base_model_prefix = "model" def __init__(self, config): super().__init__(config) - self.transformer = MistralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MistralModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() @@ -1118,7 +1115,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 855c87c363ee..5196c0415f11 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -7,6 +7,7 @@ from ...cache_utils import Cache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import QuestionAnsweringModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging @@ -87,7 +88,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama **kwargs, ) @@ -261,4 +262,69 @@ class MistralForSequenceClassification(LlamaForSequenceClassification): class MistralForQuestionAnswering(LlamaForQuestionAnswering): - pass + base_model_prefix = "model" + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) # diff with Llama: transformer->model + del self.transformer + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + loss = None + if start_positions is not None and end_positions is not None: + loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return QuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5b35cbaf61ce..66193b8c5771 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -222,13 +222,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -304,7 +301,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama **kwargs, ) @@ -1314,12 +1311,12 @@ def forward( MIXTRAL_START_DOCSTRING, ) class MixtralForQuestionAnswering(MixtralPreTrainedModel): - base_model_prefix = "transformer" + base_model_prefix = "model" def __init__(self, config): super().__init__(config) - self.transformer = MixtralModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = MixtralModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() @@ -1357,7 +1354,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 09d138f7056b..20de37010163 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -115,13 +115,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 4f33223f025f..af390c06314a 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -104,13 +104,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 6ff6e52912b3..66e8815c3fbb 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -5,7 +5,7 @@ from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import is_flash_attn_2_available, logging +from ...utils import logging from ..llama.modeling_llama import LlamaRMSNorm, eager_attention_forward from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( @@ -17,9 +17,6 @@ ) -if is_flash_attn_2_available(): - pass - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c4786a980b5a..477896decd53 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -92,13 +92,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -542,7 +539,7 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - inputs_embeds = self.embed_dropout(inputs_embeds) + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -586,7 +583,7 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.final_layernorm(hidden_states) # diff with Llama # add hidden states from the last decoder layer if output_hidden_states: diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index 05676a0deb54..a47bbb4d1b79 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -221,7 +221,7 @@ def forward( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - inputs_embeds = self.embed_dropout(inputs_embeds) + inputs_embeds = self.embed_dropout(inputs_embeds) # diff with Llama hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers @@ -265,7 +265,7 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) - hidden_states = self.final_layernorm(hidden_states) + hidden_states = self.final_layernorm(hidden_states) # diff with Llama # add hidden states from the last decoder layer if output_hidden_states: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 2815df8f8f59..aa264ac1a649 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -109,13 +109,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -199,7 +196,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=sliding_window, + sliding_window=sliding_window, # main diff with Llama **kwargs, ) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 46c8b0f8c658..055d53017026 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -93,7 +93,7 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=sliding_window, + sliding_window=sliding_window, # main diff with Llama **kwargs, ) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 66d6ce2cd8d6..b0c65ff32a15 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1556,12 +1556,12 @@ def forward( ) # Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Qwen2Moe, MISTRAL->QWEN2MOE class Qwen2MoeForQuestionAnswering(Qwen2MoePreTrainedModel): - base_model_prefix = "transformer" + base_model_prefix = "model" def __init__(self, config): super().__init__(config) - self.transformer = Qwen2MoeModel(config) self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.model = Qwen2MoeModel(config) # diff with Llama: transformer->model # Initialize weights and apply final processing self.post_init() @@ -1599,7 +1599,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.transformer( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index be09aade9de7..86a337019ba9 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -128,13 +128,10 @@ def eager_attention_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], + scaling: float, dropout: float = 0.0, - scaling: Optional[float] = None, **kwargs, ): - if scaling is None: - scaling = module.head_dim**-0.5 - key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -211,13 +208,15 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) + attn_output = nn.functional.dropout( + attn_output, p=self.residual_dropout, training=self.training + ) # diff with Llama return attn_output, attn_weights @@ -547,7 +546,9 @@ def forward( ) hidden_states = inputs_embeds - hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 8fc041dd422a..2a4d830306ba 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -33,11 +33,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack -from ...utils import ( - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, -) +from ...utils import add_start_docstrings_to_model_forward, logging from ..mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, @@ -51,10 +47,6 @@ from .configuration_starcoder2 import Starcoder2Config -if is_flash_attn_2_available(): - pass - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "Starcoder2Config" @@ -129,13 +121,15 @@ def forward( attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, - sliding_window=getattr(self.config, "sliding_window", None), + sliding_window=getattr(self.config, "sliding_window", None), # diff with Llama **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) + attn_output = nn.functional.dropout( + attn_output, p=self.residual_dropout, training=self.training + ) # diff with Llama return attn_output, attn_weights @@ -207,7 +201,9 @@ def forward( ) hidden_states = inputs_embeds - hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) + hidden_states = nn.functional.dropout( + hidden_states, p=self.embedding_dropout, training=self.training + ) # main diff with Llama # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) From f3923b6e4434ee0fdee0bfa42d7762019bbccae0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 13:57:10 +0100 Subject: [PATCH 93/97] small oupsi --- src/transformers/models/mistral/modeling_mistral.py | 4 ++-- src/transformers/models/mistral/modular_mistral.py | 6 ++++++ src/transformers/models/mixtral/modeling_mixtral.py | 4 ++-- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 4 ++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index fdad30544b00..7b2c95a650dd 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1083,10 +1083,10 @@ def __init__(self, config): self.post_init() def get_input_embeddings(self): - return self.transformer.embed_tokens + return self.model.embed_tokens def set_input_embeddings(self, value): - self.transformer.embed_tokens = value + self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 5196c0415f11..075323327e86 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -269,6 +269,12 @@ def __init__(self, config): self.model = MistralModel(config) # diff with Llama: transformer->model del self.transformer + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 66193b8c5771..84ed327d9be9 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1322,10 +1322,10 @@ def __init__(self, config): self.post_init() def get_input_embeddings(self): - return self.transformer.embed_tokens + return self.model.embed_tokens def set_input_embeddings(self, value): - self.transformer.embed_tokens = value + self.model.embed_tokens = value @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) def forward( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index b0c65ff32a15..1ce41509a5c0 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1567,10 +1567,10 @@ def __init__(self, config): self.post_init() def get_input_embeddings(self): - return self.transformer.embed_tokens + return self.model.embed_tokens def set_input_embeddings(self, value): - self.transformer.embed_tokens = value + self.model.embed_tokens = value @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) def forward( From a6a2ff9e8ea7cc56c37173e4b1996e6201fc2038 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 15:13:57 +0100 Subject: [PATCH 94/97] be more explicit in modulars --- .../models/gemma2/modular_gemma2.py | 6 + .../models/granite/modeling_granite.py | 1 - .../models/granite/modular_granite.py | 9 +- .../models/mistral/modeling_mistral.py | 106 +++++++++--------- .../models/mistral/modular_mistral.py | 14 +++ .../models/starcoder2/modeling_starcoder2.py | 2 - 6 files changed, 79 insertions(+), 59 deletions(-) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index f9ec305b2272..48b12411361a 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -360,6 +360,12 @@ def forward( class Gemma2Model(GemmaModel): + def __init__(self, config: Gemma2Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 53357dbb2b56..2e045e149d95 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -235,7 +235,6 @@ class GraniteDecoderLayer(nn.Module): def __init__(self, config: GraniteConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) self.mlp = GraniteMLP(config) diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index e8b42e96fc33..698280085f18 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -17,6 +17,7 @@ import torch import torch.utils.checkpoint +from torch import nn from ...cache_utils import Cache, DynamicCache from ...modeling_flash_attention_utils import FlashAttentionKwargs @@ -34,14 +35,15 @@ class GraniteAttention(LlamaAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: GraniteConfig, layer_idx: Optional[int] = None): - super().__init__() + super().__init__(config, layer_idx) self.scaling = config.attention_multiplier class GraniteDecoderLayer(LlamaDecoderLayer): def __init__(self, config: GraniteConfig, layer_idx: int): - super().__init__() + super().__init__(config, layer_idx) self.residual_multiplier = config.residual_multiplier + self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) def forward( self, @@ -113,6 +115,9 @@ class GraniteModel(LlamaModel): def __init__(self, config: GraniteConfig): super().__init__(config) self.embedding_multiplier = config.embedding_multiplier + self.layers = nn.ModuleList( + [GraniteDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) def forward( self, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 7b2c95a650dd..90c38895b428 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -217,6 +217,58 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + class MistralRotaryEmbedding(nn.Module): def __init__( self, @@ -282,60 +334,6 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class MistralDecoderLayer(nn.Module): - def __init__(self, config: MistralConfig, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) - - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) - - return outputs - - MISTRAL_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 075323327e86..362233a21b70 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -13,6 +13,7 @@ from ...utils import logging from ..llama.modeling_llama import ( LlamaAttention, + LlamaDecoderLayer, LlamaForCausalLM, LlamaForQuestionAnswering, LlamaForSequenceClassification, @@ -97,7 +98,20 @@ def forward( return attn_output, attn_weights +class MistralDecoderLayer(LlamaDecoderLayer): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = MistralAttention(config=config, layer_idx=layer_idx) + self.mlp = MistralMLP(config) + + class MistralModel(LlamaModel): + def __init__(self, config: MistralConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + def _update_causal_mask( self, attention_mask: torch.Tensor, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 86a337019ba9..3b4fdbcb81cc 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -225,9 +225,7 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) - self.mlp = Starcoder2MLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) From aeea33bd8b5e23633ed8c924fd63f467e17f1d34 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 15:26:30 +0100 Subject: [PATCH 95/97] more explicit modulars --- src/transformers/models/mixtral/modular_mixtral.py | 7 +++++++ src/transformers/models/olmo/modeling_olmo.py | 1 - src/transformers/models/olmo/modular_olmo.py | 4 ++++ src/transformers/models/olmo2/modeling_olmo2.py | 1 - src/transformers/models/olmo2/modular_olmo2.py | 5 +++++ src/transformers/models/phi/modular_phi.py | 3 +++ src/transformers/models/qwen2/modeling_qwen2.py | 2 -- src/transformers/models/qwen2/modular_qwen2.py | 2 ++ src/transformers/models/starcoder2/modular_starcoder2.py | 7 ++++++- 9 files changed, 27 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 2c209bcd5455..a6069f69b334 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -314,6 +314,12 @@ def forward( class MixtralModel(MistralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + def forward( self, input_ids: torch.LongTensor = None, @@ -441,6 +447,7 @@ class MixtralForCausalLM(MistralForCausalLM): def __init__(self, config): super().__init__(config) + self.model = MixtralModel(config) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 20de37010163..11d3d99f4f72 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -224,7 +224,6 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) self.mlp = OlmoMLP(config) diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 435ad5e72e61..2a43e6f9c75d 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -110,11 +110,15 @@ def __init__(self, config: OlmoConfig, layer_idx: int): super().__init__(config, layer_idx) self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + self.self_attn = OlmoAttention(config=config, layer_idx=layer_idx) class OlmoModel(LlamaModel): def __init__(self, config: OlmoConfig): super().__init__(config) + self.layers = nn.ModuleList( + [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = OlmoLayerNorm(config.hidden_size) diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index af390c06314a..49ae798e7f11 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -226,7 +226,6 @@ class Olmo2DecoderLayer(nn.Module): def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) self.mlp = Olmo2MLP(config) diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 66e8815c3fbb..5f1191708044 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -1,6 +1,7 @@ from typing import Callable, Optional, Tuple import torch +from torch import nn from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -227,6 +228,7 @@ def __init__(self, config: Olmo2Config, layer_idx: int): super().__init__(config, layer_idx=layer_idx) self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) del self.input_layernorm def forward( @@ -277,6 +279,9 @@ class Olmo2Model(OlmoModel): def __init__(self, config: Olmo2Config): super().__init__(config) self.norm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers = nn.ModuleList( + [Olmo2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) # The heads now only need to redefine the model inside to the correct `RobertaModel` diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index a47bbb4d1b79..0faa4629f1a7 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -168,6 +168,9 @@ def forward( class PhiModel(LlamaModel): def __init__(self, config: PhiConfig): super().__init__(config) + self.layers = nn.ModuleList( + [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.embed_dropout = nn.Dropout(config.embd_pdrop) self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) del self.norm diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index aa264ac1a649..36fb1ddf1390 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -229,9 +229,7 @@ class Qwen2DecoderLayer(nn.Module): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) - self.mlp = Qwen2MLP(config) self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 055d53017026..718abd01090c 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -105,6 +105,8 @@ def forward( class Qwen2DecoderLayer(LlamaDecoderLayer): def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() + self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = Qwen2MLP(config) if config.sliding_window and config._attn_implementation != "flash_attention_2": logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 2a4d830306ba..32d64cd167ba 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -137,6 +137,8 @@ def forward( class Starcoder2DecoderLayer(MistralDecoderLayer): def __init__(self, config: Starcoder2Config, layer_idx: int): super().__init__(self) + self.self_attn = Starcoder2Attention(config=config, layer_idx=layer_idx) + self.mlp = Starcoder2MLP(config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -147,8 +149,11 @@ def __init__(self, config: Starcoder2Config, layer_idx: int): class Starcoder2Model(MistralModel): def __init__(self, config: Starcoder2Config): super().__init__(config) - self.embedding_dropout = config.embedding_dropout + self.layers = nn.ModuleList( + [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.embedding_dropout = config.embedding_dropout @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING) def forward( From ec3bef3d4d216248a8db545e825461ad35bb0db1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 15:42:34 +0100 Subject: [PATCH 96/97] CIs! it works locally From fc74e3976d090b655a8278ca1b70f784ff8b7938 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 18 Dec 2024 16:16:37 +0100 Subject: [PATCH 97/97] add kwargs to _flash_attention_forward --- src/transformers/modeling_flash_attention_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index ba034772e03c..6adda0036cc0 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -247,6 +247,7 @@ def _flash_attention_forward( max_length_q: Optional[int] = None, max_length_k: Optional[int] = None, target_dtype: Optional[torch.dtype] = None, + **kwargs, ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token