Skip to content

Commit

Permalink
read olmo model from HF
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Sep 11, 2024
1 parent aa460f5 commit 7eaca42
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
22 changes: 12 additions & 10 deletions config/llama_7b_with_olmo_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ data:
tokenizer: "allenai/OLMo-1B"
model: # 7B class model
type: llama
seq_len: 2048
hidden_dim: 4096
intermediate_dim: 11008
num_layers: 32
num_heads: 32
num_kv_heads: 32
use_flash_attention: True
use_bias: false
use_layer_norm_weight: false
# seq_len: 2048
# hidden_dim: 4096
# intermediate_dim: 11008
# num_layers: 32
# num_heads: 32
# num_kv_heads: 32
# use_flash_attention: True
# use_bias: false
# use_layer_norm_weight: false
initialize_from_hf: "allenai/OLMo-1.7-7B-hf@step476000-tokens1995B"
use_hf_model_config: true
#flash_attention_block_size: 1024
trainer:
tracker:
Expand All @@ -25,7 +27,7 @@ trainer:
tags: ["dolma", "olmo", "llama"]

mp: p=f32,c=bfloat16
train_batch_size: 8
train_batch_size: 64
num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
steps_per_eval: 1000
tensor_parallel_axes: ["mlp", "heads"]
Expand Down
4 changes: 4 additions & 0 deletions src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@ def init_state_and_model(model_init, training_key):
return state

trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, model_init, training_key)
saveable_train_state = saveable_training_mask(trainer_state_shape, is_trainable)
if self.config.reset_optimizer_state:
    saveable_train_state = dataclasses.replace(saveable_train_state, optimizer=False)

state = load_checkpoint_or_initialize(
Expand Down Expand Up @@ -584,6 +586,8 @@ class TrainerConfig:
# whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes
shutdown_at_exit: Union[bool, float] = False

reset_optimizer_state: bool = False

@property
def TrainBatch(self):
return Axis("batch", self.train_batch_size)
Expand Down

0 comments on commit 7eaca42

Please sign in to comment.