WIP Kron #843

Draft · wants to merge 66 commits into main

Commits (66)
86d6556  logs (Jan 1, 2024)
8b0900f  try dataset path (WhenWen, Jan 1, 2024)
5bf8cd0  use the jax serialization manager for deser in an attempt to fix cras… (dlwh, Oct 12, 2024)
20a4568  cleaner data loader but ti doesn't help :( (dlwh, Oct 13, 2024)
2747705  ok this maybe fixed it? (dlwh, Oct 14, 2024)
538f0ed  cleanup (dlwh, Oct 14, 2024)
2c5ee4b  fix tests (dlwh, Oct 14, 2024)
ab543d6  fix what is probably the underlying problem (dlwh, Oct 14, 2024)
6395305  wip (dlwh, Oct 14, 2024)
b8f28f4  Merge branch 'main' of github.com:WhenWen/levanter-2024 into main (Dec 3, 2024)
f98e376  Implement MARS (tested) and Muon (have bug in saving), example config… (Dec 3, 2024)
e961c95  Implement MARS (tested) and Muon (have bug in saving), example config… (Dec 3, 2024)
2b80af7  wip (dlwh, Dec 3, 2024)
87d7665  enough device puts and we're good (dlwh, Dec 3, 2024)
074d0ec  ok we're good (dlwh, Dec 3, 2024)
5668289  Merge remote-tracking branch 'origin/main' into WhenWen/main (dlwh, Dec 3, 2024)
d2d310e  Merge branch 'use_manager_deser' into WhenWen/main (dlwh, Dec 3, 2024)
722edaf  fix tree leaf stuff (dlwh, Dec 3, 2024)
5692611  add map_flattened_linear_layers use in muon (dlwh, Dec 4, 2024)
2f119bd  Merge remote-tracking branch 'origin/main' into muon (dlwh, Dec 4, 2024)
0f41ebb  adding kron file to optim (evanatyourservice, Dec 4, 2024)
fe3ecc9  testing 123 (evanatyourservice, Dec 5, 2024)
37452c7  Update kron.py (evanatyourservice, Dec 5, 2024)
701956d  Update llama2_100M_kron_test.yaml (evanatyourservice, Dec 5, 2024)
aac1cee  Update llama2_100M_kron_test.yaml (evanatyourservice, Dec 5, 2024)
e44e7fa  Update llama2_100M_kron_test.yaml (evanatyourservice, Dec 5, 2024)
476ba36  Update kron.py (evanatyourservice, Dec 6, 2024)
966b80e  expose precond lr and init (evanatyourservice, Dec 7, 2024)
6408d23  Merge branch 'stanford-crfm:main' into kron (evanatyourservice, Dec 7, 2024)
91e29e7  Update kron.py (evanatyourservice, Dec 10, 2024)
6b9ce26  Merge remote-tracking branch 'upstream/main' into kron (evanatyourservice, Dec 14, 2024)
53efaa3  Update llama2_100M_kron_test.yaml (evanatyourservice, Dec 14, 2024)
311da92  Update README.md (evanatyourservice, Dec 14, 2024)
33c17f4  Update kron.py (evanatyourservice, Dec 14, 2024)
8da6e34  Update kron.py (evanatyourservice, Dec 14, 2024)
5607cec  Update kron.py (evanatyourservice, Dec 14, 2024)
2fb6c34  trust remote code (evanatyourservice, Dec 15, 2024)
336e1e1  settings defaults (evanatyourservice, Dec 15, 2024)
a5ff351  no key, deterministic, pass all into cond, more sharding (evanatyourservice, Dec 15, 2024)
3a06e1c  set key in state (evanatyourservice, Dec 15, 2024)
f7f2382  whoops (evanatyourservice, Dec 15, 2024)
07781e6  small fix (evanatyourservice, Dec 15, 2024)
9ef0869  Update kron.py (evanatyourservice, Dec 15, 2024)
ed50cce  Update kron.py (evanatyourservice, Dec 15, 2024)
1dc0f43  settings (evanatyourservice, Dec 15, 2024)
f1c1b38  small fix in init sharding (evanatyourservice, Dec 16, 2024)
7a6f501  trying repl only (evanatyourservice, Dec 16, 2024)
3473eed  Revert "trying repl only" (evanatyourservice, Dec 16, 2024)
7684702  trying while loop (evanatyourservice, Dec 16, 2024)
d284518  trying more simple psgd kron version (evanatyourservice, Dec 19, 2024)
0c920b0  Update kron.py (evanatyourservice, Dec 19, 2024)
2feff32  Merge branch 'stanford-crfm:main' into kron (evanatyourservice, Dec 19, 2024)
6a2e19f  trying simple version (evanatyourservice, Dec 19, 2024)
5108be0  take out unavailable args (evanatyourservice, Dec 19, 2024)
975a2d7  no extra args (evanatyourservice, Dec 19, 2024)
3fa70ab  trying this (evanatyourservice, Dec 19, 2024)
b62963e  Update kron.py (evanatyourservice, Dec 19, 2024)
5bafdcf  Revert "Update kron.py" (evanatyourservice, Dec 19, 2024)
c47c4c5  small fix (evanatyourservice, Dec 19, 2024)
ee747c0  settings (evanatyourservice, Dec 19, 2024)
59f2c10  small changes/moving to remote (evanatyourservice, Dec 22, 2024)
aa43e4f  Merge remote-tracking branch 'upstream/main' into kron (evanatyourservice, Dec 22, 2024)
25a2c20  simplified kron is working, need to test on larger pod (evanatyourservice, Dec 22, 2024)
88d49ed  Update kron.py (evanatyourservice, Dec 23, 2024)
4d630d8  get rid of norming and clipping in lieu of rms clip, retouches (evanatyourservice, Dec 31, 2024)
964cf19  Merge remote-tracking branch 'upstream/main' into kron (evanatyourservice, Dec 31, 2024)
Files changed
2 changes: 1 addition & 1 deletion README.md

````diff
@@ -104,7 +104,7 @@ If you're using a TPU, more complete documentation for setting that up is availa
 As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset.
 
 ```bash
-python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml
+python -m levanter.main.train_lm --config_path config/llama2_100M_kron_test.yaml
 
 # alternatively, if you didn't use -e and are in a different directory
 python -m levanter.main.train_lm --config_path gpt2_nano
````
34 changes: 34 additions & 0 deletions config/llama2_100M_kron_test.yaml

```yaml
data:
  id: openwebtext
model:
  type: llama
  seq_len: 4096
  hidden_dim: 768
  intermediate_dim: 3072
  num_layers: 12
  num_heads: 12
  num_kv_heads: 12
trainer:
  tracker:
    project: "levanter"
    tags: ["pile", "llama"]
  mp: p=f32,c=bfloat16
  model_axis_size: 1
  checkpointer:
    keep:
      - every: 1000
    save_interval: 30m

  train_batch_size: 1024
  per_device_parallelism: 32 # set for v3 TPU
  per_device_eval_parallelism: 32 # set a larger batch size for eval
  num_train_steps: 50001
optimizer:
  learning_rate: 3E-4
  weight_decay: 0.1
  warmup: 2000
  cooldown: 0.1
  lr_schedule: constant
  min_lr_ratio: 0.0
  type: kron
```
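For orientation: the `optimizer` block above is consumed by the new `KronConfig` exported from `levanter.optim` (see the `src/levanter/optim/__init__.py` diff below). Here is a minimal sketch of constructing it in Python, assuming the dataclass fields mirror the YAML keys; the field names are inferred from this config file rather than confirmed against `kron.py`:

```python
# Sketch only: KronConfig's fields are assumed to mirror the YAML keys above.
from levanter.optim import KronConfig

opt_config = KronConfig(
    learning_rate=3e-4,
    weight_decay=0.1,
    warmup=2000,
    cooldown=0.1,
    lr_schedule="constant",
    min_lr_ratio=0.0,
)
```

Per the README change above, the config can then be exercised end to end with `python -m levanter.main.train_lm --config_path config/llama2_100M_kron_test.yaml`.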
34 changes: 34 additions & 0 deletions config/llama2_100M_mars.yaml

```yaml
data: !include data/dclm_gpt_neo.yaml
model:
  type: llama
  seq_len: 4096
  hidden_dim: 768
  intermediate_dim: 3072
  num_layers: 12
  num_heads: 12
  num_kv_heads: 12
trainer:
  tracker:
    project: "levanter"
    tags: ["pile", "llama"]
  mp: p=f32,c=bfloat16
  model_axis_size: 1
  checkpointer:
    keep:
      - every: 1000
    save_interval: 30m

  train_batch_size: 1024
  per_device_parallelism: 4 # set for v3 TPU
  per_device_eval_parallelism: 4 # set a larger batch size for eval
  num_train_steps: 50001
optimizer:
  learning_rate: 4E-3 # set low for fine-tuning
  weight_decay: 0.1
  min_lr_ratio: 0.0
  warmup: 2000
  cooldown: 0.4
  lr_schedule: constant
  gamma: 0.025
  type: mars
```
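The one MARS-specific knob above is `gamma`, which scales the variance-reduction correction applied to the raw gradient before the Adam-style moment updates. As a rough sketch of what it controls, following the corrected-gradient form in the MARS paper rather than this PR's `mars.py` (the `beta1` default here is the paper's, an assumption for this branch):

```python
import jax.numpy as jnp

def mars_corrected_gradient(grad, prev_grad, gamma=0.025, beta1=0.95):
    # Variance-reduced gradient: nudge the current gradient with a scaled
    # difference against the previous step's gradient.
    c = grad + gamma * (beta1 / (1.0 - beta1)) * (grad - prev_grad)
    # The paper clips the corrected gradient back to unit norm when it is large.
    norm = jnp.linalg.norm(c)
    return jnp.where(norm > 1.0, c / norm, c)
```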
34 changes: 34 additions & 0 deletions config/llama2_100M_muon.yaml

```yaml
data: !include data/dclm_gpt_neo.yaml
model:
  type: llama
  seq_len: 4096
  hidden_dim: 768
  intermediate_dim: 3072
  num_layers: 12
  num_heads: 12
  num_kv_heads: 12
trainer:
  tracker:
    project: "levanter"
    tags: ["pile", "llama"]
  mp: p=f32,c=bfloat16
  model_axis_size: 1
  checkpointer:
    keep:
      - every: 1000
    save_interval: 30m

  train_batch_size: 1024
  per_device_parallelism: 4 # set for v3 TPU
  per_device_eval_parallelism: 4 # set a larger batch size for eval
  num_train_steps: 50001
optimizer:
  learning_rate: 2E-2 # set low for fine-tuning
  weight_decay: 0
  warmup: 0
  cooldown: 0.1
  lr_schedule: constant
  min_lr_ratio: 0.0
  max_grad_norm: 0.0
  type: muon
```
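For context on the zeroed-out settings above: Muon orthogonalizes each 2D gradient with a Newton-Schulz iteration (hence the `map_flattened_linear_layers` commit), so `weight_decay: 0` and a disabled `max_grad_norm` are plausible starting points. A generic sketch of the core iteration with the commonly used quintic coefficients; this illustrates the technique and is not taken from this PR's `muon.py`:

```python
import jax.numpy as jnp

def newton_schulz_orthogonalize(g: jnp.ndarray, steps: int = 5) -> jnp.ndarray:
    """Approximately map g to the nearest semi-orthogonal matrix."""
    a, b, c = 3.4445, -4.7750, 2.0315  # standard quintic coefficients
    x = g / (jnp.linalg.norm(g) + 1e-7)  # Frobenius norm bounds the spectral norm by 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T  # iterate on the short-and-wide orientation
    for _ in range(steps):
        xxt = x @ x.T
        x = a * x + (b * xxt + c * xxt @ xxt) @ x
    return x.T if transposed else x
```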
4 changes: 2 additions & 2 deletions src/levanter/data/audio.py

```diff
@@ -193,7 +193,7 @@ def decode(x):
 
     def doc_iterator(self, split: str) -> Iterator[Tuple[np.ndarray, int, str]]:
         if self.id is not None:
-            data = datasets.load_dataset(self.id, split=split, name=self.name, streaming=self.stream)
+            data = datasets.load_dataset(self.id, split=split, name=self.name, streaming=self.stream, trust_remote_code=True)
             for doc in data:
                 yield (doc[self.audio_key]["array"], doc[self.audio_key]["sampling_rate"], doc[self.text_key])
         else:
@@ -385,7 +385,7 @@ def _has_validation_set(self):
 
         if self.id is not None:
             dataset = datasets.load_dataset(
-                self.id, name=self.name, streaming=self.stream, split=self.validation_split
+                self.id, name=self.name, streaming=self.stream, split=self.validation_split, trust_remote_code=True
             )
             try:
                 next(iter(dataset))
```
2 changes: 1 addition & 1 deletion src/levanter/data/sharded_datasource.py

```diff
@@ -253,7 +253,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]:
     def _load_dataset(self):
         # obnoxiously, the dataset loading stuff doesn't work with ray because of multiprocessing
         # so we have to do this hacky thing where we load the dataset in the worker
-        return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs)
+        return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, trust_remote_code=True, **self.kwargs)
 
 
 class TextUrlDataSource(ShardedDataSource[str]):
```
4 changes: 2 additions & 2 deletions src/levanter/data/text.py

```diff
@@ -594,7 +594,7 @@ def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]:
 
     def doc_iterator(self, split: str):
         if self.id is not None:
-            dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream)
+            dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, trust_remote_code=True)
             data = dataset[split]
             for doc in data:
                 yield doc[self.text_key]
@@ -1065,7 +1065,7 @@ def _has_validation_set(self):
             return True
 
         if self.id is not None:
-            dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation")
+            dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation", trust_remote_code=True)
             try:
                 next(iter(dataset))
                 return True
```
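All three data-loading diffs above apply the same one-line change: forwarding `trust_remote_code=True` to Hugging Face `datasets`, so datasets that ship a loading script (such as the `openwebtext` id in the kron test config) load without an interactive prompt. A standalone sketch of the patched call, with the dataset id standing in for `self.id`:

```python
import datasets

# Mirrors the patched call sites: stream the dataset and allow its
# bundled loading script to run without prompting.
dataset = datasets.load_dataset(
    "openwebtext",
    split="train",
    streaming=True,
    trust_remote_code=True,
)
print(next(iter(dataset))["text"][:100])
```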
9 changes: 9 additions & 0 deletions src/levanter/optim/__init__.py

```diff
@@ -5,3 +5,12 @@
     scale_by_sophia_g,
     scale_by_sophia_h,
 )
+from .muon import (
+    MuonConfig,
+    ScaleByMuonState
+)
+from .mars import (
+    MarsConfig,
+    ScaleByMarsState
+)
+from .kron import KronConfig
```
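These re-exports make all three new optimizers importable from the package root. A quick smoke-test sketch, assuming the configs follow the `build(num_train_steps)` interface of levanter's existing optimizer configs (an assumption; note the commit log flags a Muon checkpoint-saving bug, so optimizer state may not round-trip yet):

```python
# Sketch: build optax-style optimizers from the newly exported configs.
# build(num_train_steps) is assumed from the usual OptimizerConfig interface.
from levanter.optim import KronConfig, MarsConfig, MuonConfig

for cfg_cls in (KronConfig, MarsConfig, MuonConfig):
    cfg = cfg_cls(learning_rate=3e-4)
    print(cfg_cls.__name__, cfg.build(num_train_steps=50001))
```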