
Can do training without FA
kylematoba committed Nov 19, 2024
1 parent a821f1e commit d6d89ae
Showing 2 changed files with 46 additions and 12 deletions.
examples/config_validation_tiny_llama.yaml (2 changes: 1 addition & 1 deletion)
@@ -130,5 +130,5 @@ tokens:
limit_val_batches: 5
micro_batch_size: 2
sequence_length: 256
train_steps: 15
train_steps: 200
val_check_interval: 2
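For scale, a quick back-of-the-envelope check of what the longer run processes, assuming a single data-parallel replica and no gradient accumulation (neither is shown in this hunk):

import_free_arithmetic = """
micro_batch_size = 2
sequence_length = 256
train_steps = 200
"""
micro_batch_size, sequence_length, train_steps = 2, 256, 200
tokens_per_step = micro_batch_size * sequence_length   # 512 tokens per micro-batch step
total_tokens = tokens_per_step * train_steps            # 102,400 tokens for the run
# multiply further by batch_accumulation_per_replica and the data-parallel size if used
print(total_tokens)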
src/nanotron/models/llama.py (56 changes: 45 additions & 11 deletions)
@@ -44,6 +44,9 @@
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
from nanotron.utils import checkpoint_method

# Module-level switch: set use_flash_attn = False to train without flash-attn installed.
# use_flash_attn = False
use_flash_attn = True

logger = logging.get_logger(__name__)
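Since the flag is hard-coded here, one possible refinement (not part of this commit) is to key it off whether flash-attn is importable. A minimal sketch:

# Hypothetical alternative to the hard-coded toggle above: fall back automatically
# when flash-attn is not installed. Not part of this commit.
try:
    import flash_attn  # noqa: F401
    use_flash_attn = True
except ImportError:
    use_flash_attn = False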


@@ -193,8 +196,6 @@ def forward(
key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim]
):
use_flash_attn = True
# use_flash_attn = False
if use_flash_attn:
from flash_attn.flash_attn_interface import flash_attn_func
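The rest of this forward is collapsed in the diff view, so the non-flash branch is not visible here. A minimal sketch of what such a fallback can look like in plain PyTorch, assuming the same [batch, seq_len, n_heads, head_dim] layout that flash_attn_func uses (an illustration, not the code from the commit):

import torch
import torch.nn.functional as F

def sdpa_fallback(query_states, key_states, value_states, causal=True):
    # flash_attn_func expects [batch, seq_len, n_heads, head_dim];
    # scaled_dot_product_attention expects [batch, n_heads, seq_len, head_dim].
    q = query_states.transpose(1, 2)
    k = key_states.transpose(1, 2)
    v = value_states.transpose(1, 2)
    # For grouped-query attention the key/value heads would additionally need to be
    # repeated (or enable_gqa used, where available) to match the query head count.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return out.transpose(1, 2)  # back to [batch, seq_len, n_heads, head_dim]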

@@ -252,6 +253,36 @@ def pad_to_right(tensor, mask, new_tensor=None):
return new_tensor, right_padded_mask


class RotaryEmbeddingKyleLikeFA(torch.nn.Module):
    """
    Drop-in replacement for flash-attn's rotary embedding, for the one way it is used
    here: interleaved=True, separate q and kv tensors, and seqlen_offset == 0.
    Unlike the flash-attn version it does not operate in place, which is fine for how
    it is called in Nanotron.
    """

    def __init__(self, dim: int, base: float):
        super().__init__()
        self.dim = dim
        self.base = float(base)

        # The torchtune module is built lazily and rebuilt whenever the sequence length changes.
        self.max_seq_len = None
        self.rpe = None

    def forward(self, q, kv):
        import torchtune.modules  # local import so the file still loads without torchtune installed

        bs, q_len, n_heads, head_dim = q.shape
        assert head_dim == self.dim
        assert kv.shape == (bs, q_len, 2, n_heads, self.dim)

        if (self.rpe is None) or (self.max_seq_len != q_len):
            self.max_seq_len = q_len
            self.rpe = torchtune.modules.RotaryPositionalEmbeddings(
                dim=self.dim, max_seq_len=self.max_seq_len, base=self.base
            ).to(q.device)

        # Rotary embeddings are applied to q and k only; v (kv[:, :, 1]) passes through unchanged.
        q_out = self.rpe(q)
        kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2)
        return q_out, kv_out

class CausalSelfAttention(nn.Module, AttachableStore):
def __init__(
self,
@@ -260,7 +291,6 @@ def __init__(
tp_pg: dist.ProcessGroup,
layer_idx: int,
):
from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding

super().__init__()
# Tensor parallel considerations: We split tensors along head dimension
@@ -321,11 +351,16 @@ def __init__(
end=config.max_position_embeddings,
theta=config.rope_theta,
)

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
)
if use_flash_attn:
    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding

    self.flash_rotary_embedding = FlashRotaryEmbedding(
        dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta
    )
else:
    import torchtune.modules  # noqa: F401  (fail fast here if torchtune is missing)

    assert config.rope_interleaved, "non-interleaved RoPE is not yet tested with the torchtune fallback"
    self.flash_rotary_embedding = RotaryEmbeddingKyleLikeFA(dim=self.d_qk, base=config.rope_theta)
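Reusing the attribute name self.flash_rotary_embedding for the torchtune-backed fallback keeps the call sites in forward unchanged: both objects take (q, kv) with the interleaved layout and return tensors of the same shapes, which is the contract described in the RotaryEmbeddingKyleLikeFA docstring above.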

self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
Expand All @@ -352,10 +387,6 @@ def forward(
sequence_mask, # [batch_size, seq_length]
):
from flash_attn import bert_padding
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)

qkv_states = self.qkv_proj(
hidden_states
@@ -408,6 +439,7 @@ def forward(
key_states = self.rotary_embedding(key_states, position_ids=position_ids)

if "key" not in store:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# First inference iteration (Prefill)
# TODO @nouamane: support custom masking
# assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
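With the import now deferred into this branch, flash-attn is only required once the inference/prefill path actually runs. For orientation, the general varlen pattern (unpad, run the kernel over packed sequences, re-pad) looks roughly like the sketch below. The actual call in this file is collapsed in the diff view, the argument tensors here are placeholders, and unpad_input's return signature differs slightly across flash-attn versions:

from flash_attn import bert_padding
from flash_attn.flash_attn_interface import flash_attn_varlen_func

def varlen_prefill_attention(query_states, key_states, value_states, sequence_mask):
    """Illustrative only: pack padded [batch, seq, heads, dim] tensors, run the
    varlen kernel, then scatter the result back to the padded layout."""
    batch_size, q_length = query_states.shape[:2]

    # unpad_input returns (packed, indices, cu_seqlens, max_seqlen, ...) depending on version
    q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = bert_padding.unpad_input(query_states, sequence_mask)[:4]
    k_unpad, _, cu_seqlens_k, max_seqlen_k = bert_padding.unpad_input(key_states, sequence_mask)[:4]
    v_unpad = bert_padding.unpad_input(value_states, sequence_mask)[0]

    out_unpad = flash_attn_varlen_func(
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
        causal=True,
    )
    return bert_padding.pad_input(out_unpad, indices_q, batch_size, q_length)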
@@ -468,6 +500,8 @@ def forward(
pad_to_right(value_states, sequence_mask, new_tensor=v_cache)

else:
from flash_attn.flash_attn_interface import flash_attn_with_kvcache

# Pull pre-computed key/value states
# Subsequent inference iterations (q_length=1)
k_cache = store["key"]
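The decode-time call itself is also collapsed in this view. The usual shape of a flash_attn_with_kvcache step, with the new key/value appended into the pre-allocated caches, is roughly as follows (illustrative sketch; the tensor and variable names here are placeholders, not the file's own):

from flash_attn.flash_attn_interface import flash_attn_with_kvcache

def decode_step(query_states, key_states, value_states, k_cache, v_cache, cache_lengths):
    # Single-token decode (q_length == 1): the new key/value are appended into the
    # pre-allocated caches in place; cache_lengths gives the valid cached positions per sequence.
    return flash_attn_with_kvcache(
        query_states,        # [batch, 1, n_local_q_heads, head_dim]
        k_cache,             # [batch, max_cache_len, n_local_kv_heads, head_dim]
        v_cache,
        k=key_states,
        v=value_states,
        cache_seqlens=cache_lengths,
        causal=True,
    )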
