diff --git a/examples/config_validation_tiny_llama.yaml b/examples/config_validation_tiny_llama.yaml
index ef305f4b..8e7a290c 100644
--- a/examples/config_validation_tiny_llama.yaml
+++ b/examples/config_validation_tiny_llama.yaml
@@ -130,5 +130,5 @@ tokens:
   limit_val_batches: 5
   micro_batch_size: 2
   sequence_length: 256
-  train_steps: 15
+  train_steps: 200
   val_check_interval: 2
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 1c796a1e..61ffbba8 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -44,6 +44,9 @@
 from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
 from nanotron.utils import checkpoint_method
 
+# use_flash_attn = False
+use_flash_attn = True
+
 logger = logging.get_logger(__name__)
 
 
@@ -193,8 +196,6 @@ def forward(
         key_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
         value_states: torch.Tensor,  # [batch_size, kv_length, n_local_kv_heads, inner_dim]
     ):
-        use_flash_attn = True
-        # use_flash_attn = False
         if use_flash_attn:
             from flash_attn.flash_attn_interface import flash_attn_func
 
@@ -252,6 +253,36 @@ def pad_to_right(tensor, mask, new_tensor=None):
     return new_tensor, right_padded_mask
 
 
+class RotaryEmbeddingKyleLikeFA(torch.nn.Module):
+    """
+    Mimics the call signature of flash-attn's rotary embedding, for interleaved=True and separate q, kv.
+    Assumes seqlen_offset = 0.
+    Does not operate in place, but that's fine for how it's used in Nanotron.
+    """
+    def __init__(self, dim: int, base: float):
+        super().__init__()
+        self.dim = dim
+        self.base = float(base)
+
+        self.max_seq_len = None
+        self.rpe = None
+
+    def forward(self, q, kv):
+        bs, q_len, n_heads, head_dim = q.shape
+        assert self.dim == head_dim
+
+        assert (bs, q_len, 2, n_heads, self.dim) == kv.shape
+
+        if (self.rpe is None) or (self.max_seq_len != q_len):
+            import torchtune.modules
+            self.max_seq_len = q_len
+            self.rpe = torchtune.modules.RotaryPositionalEmbeddings(
+                dim=self.dim, max_seq_len=self.max_seq_len, base=self.base).to(q.device)
+        q_out = self.rpe(q)
+        kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2)
+        return q_out, kv_out
+
+
 class CausalSelfAttention(nn.Module, AttachableStore):
     def __init__(
         self,
@@ -260,7 +291,6 @@ def __init__(
         tp_pg: dist.ProcessGroup,
         layer_idx: int,
     ):
-        from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
 
         super().__init__()
         # Tensor parallel considerations: We split tensors along head dimension
@@ -321,11 +351,16 @@ def __init__(
             end=config.max_position_embeddings,
             theta=config.rope_theta,
         )
 
-        # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
-        self.flash_rotary_embedding = FlashRotaryEmbedding(
-            dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
-        )
+        if use_flash_attn:
+            from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+            self.flash_rotary_embedding = FlashRotaryEmbedding(
+                dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
+            )
+        else:
+            import torchtune.modules
+            assert config.rope_interleaved, "this case not yet tested"
+            self.flash_rotary_embedding = RotaryEmbeddingKyleLikeFA(dim=self.d_qk, base=config.rope_theta)
 
         self.o_proj = TensorParallelRowLinear(
             config.num_attention_heads * self.d_qk,
@@ -352,10 +387,6 @@ def forward(
         sequence_mask,  # [batch_size, seq_length]
     ):
         from flash_attn import bert_padding
-        from flash_attn.flash_attn_interface import (
-            flash_attn_varlen_func,
-            flash_attn_with_kvcache,
-        )
 
         qkv_states = self.qkv_proj(
             hidden_states
@@ -408,6 +439,7 @@ def forward(
             key_states = self.rotary_embedding(key_states, position_ids=position_ids)
 
             if "key" not in store:
+                from flash_attn.flash_attn_interface import flash_attn_varlen_func
                 # First inference iteration (Prefill)
                 # TODO @nouamane: support custom masking
                 # assert that [ False, False, False, False,  True,  True,  True,  True,  True,  True] is accepted
@@ -468,6 +500,8 @@ def forward(
                 pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
 
             else:
+                from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+
                 # Pull pre-computed key/value states
                 # Subsequent inference iterations (q_length=1)
                 k_cache = store["key"]
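Note (not part of the patch): a minimal shape-level smoke test for the torchtune-backed RotaryEmbeddingKyleLikeFA fallback added above. It assumes the patch is applied and torchtune is installed; the tensor sizes and the rope base of 10000.0 are illustrative.

# Sketch: check shapes and that values pass through the fallback rotary embedding unrotated.
import torch

from nanotron.models.llama import RotaryEmbeddingKyleLikeFA

batch, seq_len, n_heads, head_dim = 2, 256, 8, 64
rope = RotaryEmbeddingKyleLikeFA(dim=head_dim, base=10000.0)

q = torch.randn(batch, seq_len, n_heads, head_dim)      # [batch_size, seq_length, n_heads, head_dim]
kv = torch.randn(batch, seq_len, 2, n_heads, head_dim)  # [batch_size, seq_length, 2 (k then v), n_heads, head_dim]

q_rot, kv_rot = rope(q, kv)
assert q_rot.shape == q.shape and kv_rot.shape == kv.shape
# Only q and k are rotated; v passes through unchanged, as with flash-attn's rotary embedding.
assert torch.equal(kv_rot[:, :, 1], kv[:, :, 1])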