Add "blocked"/"flash" cross entropy (#790)
to mitigate large tokenizers (e.g. llama3) limiting the block size

This imposes a somewhat awkward refactor on LMModel, but it's not the worst.

FYI @Helw150

---------

Co-authored-by: Ivan Zhou <[email protected]>
Co-authored-by: Abhinav Garg <[email protected]>
3 people authored Nov 6, 2024
1 parent 1c43256 commit a7e42ec
Showing 16 changed files with 880 additions and 44 deletions.
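Before the per-file diffs, a bit of context on the technique itself. "Blocked" (or "flash") cross entropy computes the loss one vocabulary block at a time, keeping a running logsumexp, so the full [batch, Pos, Vocab] logits array is never materialized; with a ~128K-entry tokenizer like llama3's, that array is what blows up memory. The sketch below is not the code added in this commit (Levanter's version operates on Haliax NamedArrays and is configured via `cross_entropy_block_size`, visible in the `lm_model.py` hunk further down); it is a minimal jax.numpy illustration of the idea, and `blocked_cross_entropy` and its signature are invented for the example.

```python
# Minimal jax.numpy sketch of blocked cross entropy (illustration only; the name
# and signature are hypothetical, not Levanter's API). The loss is computed one
# vocab block at a time so the full [num_tokens, vocab] logits array is never built.
import jax
import jax.numpy as jnp


def blocked_cross_entropy(acts, lm_head, targets, block_size=64_000):
    """acts: [num_tokens, embed], lm_head: [embed, vocab], targets: [num_tokens] ints."""
    num_tokens = acts.shape[0]
    vocab = lm_head.shape[1]
    num_blocks = -(-vocab // block_size)                       # ceil division
    padded = num_blocks * block_size
    lm_head = jnp.pad(lm_head, ((0, 0), (0, padded - vocab)))  # pad so every block is full width

    def body(carry, i):
        lse, target_logit = carry
        start = i * block_size
        block = jax.lax.dynamic_slice_in_dim(lm_head, start, block_size, axis=1)
        logits = acts @ block                                  # [num_tokens, block_size]
        cols = start + jnp.arange(block_size)
        logits = jnp.where(cols[None, :] < vocab, logits, -jnp.inf)  # mask padding columns
        # running logsumexp over the vocab dimension
        lse = jnp.logaddexp(lse, jax.nn.logsumexp(logits, axis=-1))
        # pick up each target's logit if it falls inside this block
        in_block = (targets >= start) & (targets < start + block_size)
        local = jnp.take_along_axis(
            logits, jnp.clip(targets - start, 0, block_size - 1)[:, None], axis=1
        )[:, 0]
        target_logit = jnp.where(in_block, local, target_logit)
        return (lse, target_logit), None

    init = (jnp.full(num_tokens, -jnp.inf, acts.dtype), jnp.zeros(num_tokens, acts.dtype))
    (lse, target_logit), _ = jax.lax.scan(body, init, jnp.arange(num_blocks))
    return lse - target_logit                                  # per-token negative log-likelihood
```

The "flash" naming presumably refers to the same trick applied in the backward pass (recomputing block logits rather than storing them); the sketch only covers the forward loss.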
32 changes: 32 additions & 0 deletions config/llama3_small_fast.yaml
@@ -0,0 +1,32 @@
data:
train_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
validation_urls:
- "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/"
tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
model:
type: llama
hidden_dim: 768
intermediate_dim: 2048
num_heads: 12
num_kv_heads: 12
num_layers: 12
seq_len: 1024
gradient_checkpointing: true
trainer:
tracker:
- type: wandb
project: "levanter"
tags: [ "openwebtext", "llama", "itest"]

mp: p=f32,c=bfloat16
model_axis_size: 1
per_device_parallelism: -1

train_batch_size: 256
num_train_steps: 20000
optimizer:
learning_rate: 1E-3
weight_decay: 0.1
warmup: 0.01
2 changes: 1 addition & 1 deletion config/llama_7b_with_dclm.yaml
@@ -17,7 +17,7 @@ trainer:

mp: p=f32,c=bfloat16
train_batch_size: 2048
num_train_steps: 70000 # 280B / 4M
num_train_steps: 480000 # 2T / 4M
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
3 changes: 2 additions & 1 deletion src/levanter/infra/ray_tpu.py
@@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):

# ray doesn't merge the runtime envs properly, so we have to do it ourselves
# we need to do a deep merge
runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE)
sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None]
runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)

remote_fn = remote_fn.options(
runtime_env=runtime_env,
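A small standalone illustration of the changed merge (the env dicts below are invented; only the `mergedeep` usage mirrors the lines above): filtering out `None` sources first avoids handing `None` to `mergedeep.merge`, which expects mapping sources, so the merge still works when the remote fn or the caller supplies no runtime env.

```python
# Standalone illustration of the runtime-env merge above. The example envs are
# made up; only the mergedeep usage mirrors the diff.
import mergedeep

fn_env = None                                         # e.g. a remote fn with no runtime env attached
call_env = {"env_vars": {"TPU_NAME": "v4-8"}, "pip": ["levanter"]}

# filter out None before merging, since mergedeep expects mapping sources
sources = [e for e in [fn_env, call_env] if e is not None]
merged = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)
print(merged)  # {'env_vars': {'TPU_NAME': 'v4-8'}, 'pip': ['levanter']}
```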
7 changes: 4 additions & 3 deletions src/levanter/models/backpack.py
@@ -401,7 +401,7 @@ def init(Vocab: Axis, config: BackpackConfig, *, key):
)

@named_call
def __call__(
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
k_embed, k_transformer, k_senses, k_sa = haliax.jax_utils.maybe_rng_split(key, 4)
@@ -428,9 +428,10 @@ def __call__(
scale = self.config.Senses.size
hidden_states = hidden_states / scale

lm_logits = self.embeddings.unembed(hidden_states)
return hidden_states

return lm_logits
def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings

def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None):
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
8 changes: 5 additions & 3 deletions src/levanter/models/gemma.py
@@ -339,14 +339,17 @@ def vocab_size(self) -> int:
def Vocab(self) -> Axis:
return self.embeddings.Vocab

def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings.weight

@classmethod
def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel":
k_t, k_emb = jrandom.split(key, 2)
transformer = GemmaTransformer.init(config, key=k_t)
embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
return GemmaLMHeadModel(transformer, embeddings)

def __call__(
def activations(
self,
input_ids: NamedArray,
attn_mask: Optional[Union[NamedArray, AttentionMask]] = None,
@@ -363,8 +366,7 @@ def __call__(
"""
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=key)
lm_logits = self.embeddings.unembed(x)
return lm_logits
return x

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]":
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
8 changes: 5 additions & 3 deletions src/levanter/models/gpt2.py
@@ -391,15 +391,17 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel":

return Gpt2LMHeadModel(transformer, embeddings)

def __call__(
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
k_embed, k_transformer = haliax.jax_utils.maybe_rng_split(key, 2)
x = self.embeddings.embed(input_ids, key=k_embed)
x = self.transformer(x, attn_mask, key=k_transformer)
lm_logits = self.embeddings.unembed(x)

return lm_logits
return x

def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings.weight

def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "Gpt2LMHeadModel":
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
25 changes: 25 additions & 0 deletions src/levanter/models/llama.py
@@ -557,6 +557,31 @@ def __call__(
lm_logits = self.embeddings.unembed(x)
return lm_logits

def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the activations for the next token in a sequence.
Args:
input_ids: token IDs with shape {Pos}
attn_mask: attention mask with shape {Pos, KeyPos}
key: PRNGKey for random number generation
Returns:
NamedArray: activations with shape {Pos, Embed}
"""
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=key)

return x

def get_lm_head(self) -> hax.NamedArray:
if self.lm_head is None:
return self.embeddings.token_embeddings.weight
else:
return self.lm_head.weight

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
new_Vocab = self.Vocab.resize(new_size)
k1, k2 = maybe_rng_split(key, 2)
65 changes: 60 additions & 5 deletions src/levanter/models/lm_model.py
@@ -64,6 +64,18 @@ def KeyPos(self) -> Axis:
def Pos(self) -> Axis:
pass

@property
@abc.abstractmethod
def Embed(self) -> Axis:
pass

cross_entropy_block_size: Optional[int] = 64000
"""
The block size for computing cross-entropy loss. This is the number of tokens that are processed together
in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large
value because it is usually faster to compute the loss in larger blocks.
"""

def flops_per_token(self, vocab_size: int) -> Optional[float]:
return None

@@ -94,17 +106,58 @@ def Pos(self) -> Axis:
def KeyPos(self) -> Axis:
return self.config.KeyPos

@property
def Embed(self) -> Axis:
return self.config.Embed

@classmethod
@abc.abstractmethod
def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[LmConfigT]":
pass

@abc.abstractmethod
def __call__(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the logits for the next token in a sequence.
Args:
input_ids: token IDs with shape [..., Pos]
attn_mask: attention mask with shape [..., Pos, KeyPos]
key: PRNGKey for random number generation
Returns:
NamedArray: logits with shape [..., Pos, Vocab]
"""
x = self.activations(input_ids, attn_mask, key=key)
lm_logits = hax.dot(x, self.get_lm_head(), axis=self.Embed)

return lm_logits

@abc.abstractmethod
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the activations for the next token in a sequence.
Args:
input_ids: token IDs with shape {Pos}
attn_mask: attention mask with shape {Pos, KeyPos}
key: PRNGKey for random number generation
Returns:
NamedArray: activations with shape {Pos, Embed}
"""
pass

@abc.abstractmethod
def get_lm_head(self) -> hax.NamedArray:
"""
The language modeling head of the model. Should have shape {Embed, Vocab}.
"""
raise NotImplementedError("get_lm_head not implemented")

@abc.abstractmethod
def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]":
"""
@@ -133,19 +186,21 @@ def compute_next_token_loss(
across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
reduced, and the result is a named array with axes (*batch axes, sequence_length).
"""
logits = model(example.tokens, example.attn_mask, key=key)
if loss_dtype is not None:
logits = logits.astype(loss_dtype)
activations = model.activations(example.tokens, example.attn_mask, key=key)

loss = next_token_loss(
model.Pos,
model.Embed,
model.Vocab,
logits,
activations,
model.get_lm_head(),
example.tokens,
loss_mask=example.loss_mask,
reduction=reduction,
reduction_axis=reduction_axis,
logsumexp_weight=logsumexp_weight,
dtype=loss_dtype,
block_size=model.config.cross_entropy_block_size,
)

return loss
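Taken together, the refactor's shape across the model files above is: each model now returns final hidden states from `activations()`, exposes its unembedding matrix via `get_lm_head()`, and `compute_next_token_loss` hands both to the loss so logits can be formed block by block. Below is a toy, shape-only use of the earlier `blocked_cross_entropy` sketch with random stand-in arrays; it is not Levanter's API.

```python
# Toy usage of the blocked_cross_entropy sketch from earlier (shape-only, random data).
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
acts = jax.random.normal(k1, (8, 16))           # stand-in for model.activations(...): 8 tokens, Embed=16
lm_head = jax.random.normal(k2, (16, 128_256))  # stand-in for model.get_lm_head(): Embed x Vocab (llama3-sized)
targets = jnp.array([1, 5, 9, 100_000, 2, 3, 4, 7])

per_token_nll = blocked_cross_entropy(acts, lm_head, targets, block_size=32_000)
print(per_token_nll.shape)  # (8,)
```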
(Diffs for the remaining changed files are not shown.)
