Add "blocked"/"flash" cross entropy #790

Merged
merged 46 commits on Nov 6, 2024
Commits
fa9fd25
Add llama fineweb yaml
Ivan-Zhou Jun 25, 2024
ac80e57
small modification
Ivan-Zhou Jun 25, 2024
8b9bd78
pre commit checks
Ivan-Zhou Jun 25, 2024
525f4a6
mypy
Ivan-Zhou Jun 25, 2024
92425ca
add fire
Ivan-Zhou Jun 25, 2024
1d9f6f8
add more html
Ivan-Zhou Jun 25, 2024
b1c0905
add more md urls
Ivan-Zhou Jun 25, 2024
89b6192
delete get_files_on_gcs.py
Ivan-Zhou Jun 26, 2024
0b34522
CC-MAIN-*/*/*_processed_md.jsonl.gz
Ivan-Zhou Jun 27, 2024
4aa2f2c
Adding configs related to DCLM
abhinavg4 Jul 18, 2024
dde9ed0
Adding configs related to DCLM
abhinavg4 Jul 19, 2024
b991e29
Adding Z loss
abhinavg4 Jul 19, 2024
bb674bb
pre commit changes
abhinavg4 Jul 19, 2024
6c99dfb
Adding z_loss as part of train_lm.py
abhinavg4 Jul 19, 2024
24469e7
Reverting changes to llama.py for z_loss
abhinavg4 Jul 19, 2024
e12c1b6
Merge remote-tracking branch 'origin/dclm' into dclm
dlwh Aug 20, 2024
c9ebc88
match specs in dclm
dlwh Aug 20, 2024
7727696
publish dev build
dlwh Aug 21, 2024
55e4d98
wip
dlwh Aug 21, 2024
de51236
fix imports and such
dlwh Aug 22, 2024
7863989
get default zone from gcloud config
dlwh Aug 22, 2024
a550bb5
factor out docker command, build
dlwh Aug 22, 2024
6341252
Merge remote-tracking branch 'origin/main' into dclm
dlwh Aug 22, 2024
e9ca517
Merge remote-tracking branch 'origin/main' into dclm
dlwh Aug 28, 2024
d674dd9
wip
dlwh Aug 29, 2024
06dc304
wip
dlwh Aug 29, 2024
f13cfde
bump equinox
dlwh Sep 5, 2024
8d3dfe0
wip
dlwh Sep 6, 2024
8ecb7ea
768
dlwh Sep 6, 2024
0ea3eb4
Merge remote-tracking branch 'origin/main' into dclm
dlwh Oct 14, 2024
2f53923
wip
dlwh Oct 21, 2024
9050258
wip
dlwh Oct 30, 2024
b15e5d3
Merge remote-tracking branch 'origin/main' into blocked_cross_entropy
dlwh Oct 30, 2024
d6a3ded
wip
dlwh Oct 30, 2024
2e25357
it works?!?
dlwh Nov 4, 2024
2390058
tuning. just about there
dlwh Nov 5, 2024
795fd08
wip
dlwh Nov 5, 2024
05afef0
Merge remote-tracking branch 'origin/main' into blocked_cross_entropy
dlwh Nov 6, 2024
0de1482
pre-commit
dlwh Nov 6, 2024
fc01b9e
remove stray files
dlwh Nov 6, 2024
d907aa4
implement lm_head
dlwh Nov 6, 2024
eee3ecd
implement lm_head
dlwh Nov 6, 2024
879d5c0
misc test fixes
dlwh Nov 6, 2024
6fe0fb8
Merge remote-tracking branch 'origin/blocked_cross_entropy' into bloc…
dlwh Nov 6, 2024
5cf5f17
increase tolerances
dlwh Nov 6, 2024
db08341
increase tolerances
dlwh Nov 6, 2024
32 changes: 32 additions & 0 deletions config/llama3_small_fast.yaml
@@ -0,0 +1,32 @@
data:
  train_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz"
  validation_urls:
    - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz"
  cache_dir: "gs://levanter-data/tokenized/openwebtext_llama3/"
  tokenizer: "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"
model:
  type: llama
  hidden_dim: 768
  intermediate_dim: 2048
  num_heads: 12
  num_kv_heads: 12
  num_layers: 12
  seq_len: 1024
  gradient_checkpointing: true
trainer:
  tracker:
    - type: wandb
      project: "levanter"
      tags: [ "openwebtext", "llama", "itest"]

  mp: p=f32,c=bfloat16
  model_axis_size: 1
  per_device_parallelism: -1

  train_batch_size: 256
  num_train_steps: 20000
optimizer:
  learning_rate: 1E-3
  weight_decay: 0.1
  warmup: 0.01
2 changes: 1 addition & 1 deletion config/llama_7b_with_dclm.yaml
@@ -17,7 +17,7 @@ trainer:

mp: p=f32,c=bfloat16
train_batch_size: 2048
num_train_steps: 70000 # 280B / 4M
num_train_steps: 480000 # 2T / 4M
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
fsdp_axis: "embed"
3 changes: 2 additions & 1 deletion src/levanter/infra/ray_tpu.py
@@ -201,7 +201,8 @@ def _redecorate_remote_fn_for_tpu(remote_fn, num_hosts, **runtime_env):

# ray doesn't merge the runtime envs properly, so we have to do it ourselves
# we need to do a deep merge
runtime_env = mergedeep.merge({}, runtime_env, remote_fn._runtime_env, strategy=mergedeep.Strategy.ADDITIVE)
sources = [e for e in [remote_fn._runtime_env, runtime_env] if e is not None]
runtime_env = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)

remote_fn = remote_fn.options(
runtime_env=runtime_env,
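For context, a minimal sketch (not code from this PR) of the merge behavior the new hunk relies on: the hunk suggests that either runtime environment may be `None`, so the sources are filtered before handing them to `mergedeep.merge`. The example values below are hypothetical.

```python
# Minimal sketch (assumed example values, not the PR's code) of merging Ray
# runtime environments with mergedeep while skipping sources that are None.
import mergedeep

base_env = {"env_vars": {"TPU_NAME": "my-tpu"}, "pip": ["ray"]}  # hypothetical
fn_env = None  # e.g. the remote fn was declared without a runtime_env

sources = [e for e in (fn_env, base_env) if e is not None]
merged = mergedeep.merge({}, *sources, strategy=mergedeep.Strategy.ADDITIVE)
print(merged)  # {'env_vars': {'TPU_NAME': 'my-tpu'}, 'pip': ['ray']}
```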
7 changes: 4 additions & 3 deletions src/levanter/models/backpack.py
@@ -401,7 +401,7 @@ def init(Vocab: Axis, config: BackpackConfig, *, key):
)

@named_call
def __call__(
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
k_embed, k_transformer, k_senses, k_sa = haliax.jax_utils.maybe_rng_split(key, 4)
@@ -428,9 +428,10 @@ def __call__(
scale = self.config.Senses.size
hidden_states = hidden_states / scale

lm_logits = self.embeddings.unembed(hidden_states)
return hidden_states

return lm_logits
def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings

def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None):
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
8 changes: 5 additions & 3 deletions src/levanter/models/gemma.py
@@ -339,14 +339,17 @@ def vocab_size(self) -> int:
def Vocab(self) -> Axis:
return self.embeddings.Vocab

def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings.weight

@classmethod
def init(cls, Vocab: Axis, config: GemmaConfig, *, key) -> "GemmaLMHeadModel":
k_t, k_emb = jrandom.split(key, 2)
transformer = GemmaTransformer.init(config, key=k_t)
embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
return GemmaLMHeadModel(transformer, embeddings)

def __call__(
def activations(
self,
input_ids: NamedArray,
attn_mask: Optional[Union[NamedArray, AttentionMask]] = None,
@@ -363,8 +366,7 @@ def __call__(
"""
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=key)
lm_logits = self.embeddings.unembed(x)
return lm_logits
return x

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[GemmaConfig]":
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
8 changes: 5 additions & 3 deletions src/levanter/models/gpt2.py
@@ -391,15 +391,17 @@ def init(cls, Vocab: Axis, config: Gpt2Config, *, key) -> "Gpt2LMHeadModel":

return Gpt2LMHeadModel(transformer, embeddings)

def __call__(
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
k_embed, k_transformer = haliax.jax_utils.maybe_rng_split(key, 2)
x = self.embeddings.embed(input_ids, key=k_embed)
x = self.transformer(x, attn_mask, key=k_transformer)
lm_logits = self.embeddings.unembed(x)

return lm_logits
return x

def get_lm_head(self) -> hax.NamedArray:
return self.embeddings.token_embeddings.weight

def resize_vocab(self, new_size: int, key: Optional[PRNGKeyArray] = None) -> "Gpt2LMHeadModel":
new_embeddings = self.embeddings.resize_embeddings(new_size, key=key)
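The recurring change across backpack, gemma, and gpt2 (and llama below) is the same decomposition: `__call__`, which used to return logits, becomes `activations`, and the unembedding matrix is exposed via `get_lm_head()`. A minimal haliax sketch of how the pieces fit back together, with hypothetical axis sizes (this is not levanter's code):

```python
# Sketch of the new decomposition: a model returns hidden activations, and
# logits are recovered on demand by contracting against the lm head over Embed.
import haliax as hax
import jax.random as jrandom

Pos = hax.Axis("position", 8)   # hypothetical sizes for illustration
Embed = hax.Axis("embed", 16)
Vocab = hax.Axis("vocab", 32)

k1, k2 = jrandom.split(jrandom.PRNGKey(0))
activations = hax.random.normal(k1, (Pos, Embed))  # what model.activations(...) returns
lm_head = hax.random.normal(k2, (Embed, Vocab))    # what model.get_lm_head() returns

# This is the projection LmHeadModel.__call__ performs in lm_model.py below.
logits = hax.dot(activations, lm_head, axis=Embed)
assert set(logits.axes) == {Pos, Vocab}
```

Callers that only need the loss can skip this projection entirely and hand the activations plus the head to the blocked loss, which is the point of the `lm_model.py` changes further down.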
25 changes: 25 additions & 0 deletions src/levanter/models/llama.py
@@ -557,6 +557,31 @@ def __call__(
lm_logits = self.embeddings.unembed(x)
return lm_logits

def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the activations for the next token in a sequence.
Args:
input_ids: token IDs with shape {Pos}
attn_mask: attention mask with shape {Pos, KeyPos}
key: PRNGKey for random number generation
Returns:
NamedArray: activations with shape {Pos, Embed}
"""
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=key)

return x

def get_lm_head(self) -> hax.NamedArray:
if self.lm_head is None:
return self.embeddings.token_embeddings.weight
else:
return self.lm_head.weight

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
new_Vocab = self.Vocab.resize(new_size)
k1, k2 = maybe_rng_split(key, 2)
65 changes: 60 additions & 5 deletions src/levanter/models/lm_model.py
@@ -64,6 +64,18 @@ def KeyPos(self) -> Axis:
def Pos(self) -> Axis:
pass

@property
@abc.abstractmethod
def Embed(self) -> Axis:
pass

cross_entropy_block_size: Optional[int] = 64000
"""
The block size for computing cross-entropy loss. This is the number of tokens that are processed together
in a single block. This can be adjusted to fit within memory constraints. It's deliberately set to a large
value because it is usually faster to compute the loss in larger blocks.
"""

def flops_per_token(self, vocab_size: int) -> Optional[float]:
return None

@@ -94,17 +106,58 @@ def Pos(self) -> Axis:
def KeyPos(self) -> Axis:
return self.config.KeyPos

@property
def Embed(self) -> Axis:
return self.config.Embed

@classmethod
@abc.abstractmethod
def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[LmConfigT]":
pass

@abc.abstractmethod
def __call__(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the logits for the next token in a sequence.
Args:
input_ids: token IDs with shape [..., Pos]
attn_mask: attention mask with shape [..., Pos, KeyPos]
key: PRNGKey for random number generation

Returns:
NamedArray: logits with shape [..., Pos, Vocab]

"""
x = self.activations(input_ids, attn_mask, key=key)
lm_logits = hax.dot(x, self.get_lm_head(), axis=self.Embed)

return lm_logits

@abc.abstractmethod
def activations(
self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None
) -> NamedArray:
"""
Compute the activations for the next token in a sequence.
Args:
input_ids: token IDs with shape {Pos}
attn_mask: attention mask with shape {Pos, KeyPos}
key: PRNGKey for random number generation

Returns:
NamedArray: activations with shape {Pos, Embed}

"""
pass

@abc.abstractmethod
def get_lm_head(self) -> hax.NamedArray:
"""
The language modeling head of the model. Should have shape {Embed, Vocab}.
"""
raise NotImplementedError("get_lm_head not implemented")

@abc.abstractmethod
def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]":
"""
@@ -133,19 +186,21 @@ def compute_next_token_loss(
across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not
reduced, and the result is a named array with axes (*batch axes, sequence_length).
"""
logits = model(example.tokens, example.attn_mask, key=key)
if loss_dtype is not None:
logits = logits.astype(loss_dtype)
activations = model.activations(example.tokens, example.attn_mask, key=key)

loss = next_token_loss(
model.Pos,
model.Embed,
model.Vocab,
logits,
activations,
model.get_lm_head(),
example.tokens,
loss_mask=example.loss_mask,
reduction=reduction,
reduction_axis=reduction_axis,
logsumexp_weight=logsumexp_weight,
dtype=loss_dtype,
block_size=model.config.cross_entropy_block_size,
)

return loss
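The reason `compute_next_token_loss` now passes activations and the lm head into `next_token_loss` (rather than precomputed logits) is so the loss can be evaluated over vocabulary blocks, using the `cross_entropy_block_size` added above, without materializing the full `Pos × Vocab` logits array. The sketch below is an illustrative reimplementation of that idea in plain JAX; the function name, blocking scheme, and shapes are assumptions, not levanter's `next_token_loss`.

```python
# Illustrative sketch of "blocked" cross entropy (not levanter's implementation):
# walk the vocabulary in blocks so the full [pos, vocab] logits matrix is never
# materialized, using an online logsumexp in the style of flash attention.
import jax.numpy as jnp


def blocked_cross_entropy(acts, lm_head, labels, block_size=4):
    """acts: [pos, embed]; lm_head: [embed, vocab]; labels: [pos] int."""
    vocab = lm_head.shape[1]
    assert vocab % block_size == 0  # keep the sketch simple

    pos = acts.shape[0]
    running_max = jnp.full(pos, -jnp.inf)  # running max over vocab seen so far
    running_sum = jnp.zeros(pos)           # sum of exp(logit - running_max)
    label_logit = jnp.zeros(pos)           # logit of the true label

    for lo in range(0, vocab, block_size):
        block = acts @ lm_head[:, lo : lo + block_size]  # [pos, block]

        # Online logsumexp update: rescale the old sum to the new running max.
        block_max = block.max(axis=-1)
        new_max = jnp.maximum(running_max, block_max)
        running_sum = running_sum * jnp.exp(running_max - new_max) + jnp.exp(
            block - new_max[:, None]
        ).sum(axis=-1)
        running_max = new_max

        # Accumulate the true-label logit if it falls inside this block.
        in_block = (labels >= lo) & (labels < lo + block_size)
        idx = jnp.clip(labels - lo, 0, block_size - 1)
        picked = jnp.take_along_axis(block, idx[:, None], axis=-1)[:, 0]
        label_logit = label_logit + jnp.where(in_block, picked, 0.0)

    # Cross entropy = logsumexp(logits) - logit[label], averaged over positions.
    return (running_max + jnp.log(running_sum) - label_logit).mean()
```

For small shapes this matches the loss computed from full logits with a standard softmax cross entropy. Note that the real implementation also needs care on the backward pass (e.g. recomputation or a custom VJP) to realize the memory savings; this sketch only illustrates the forward blocking.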