Skip to content

Commit

Permalink
initialize
Browse files (browse the repository at this point in the history)
  • Loading branch information
dlwh committed Nov 30, 2024
1 parent 3179f67 commit 69ea6b1
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/levanter/main/train_lm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -134,10 +134,6 @@ def main(config: TrainLmConfig):
if vocab_size != Vocab.size:
logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

logger.info(f"initializing model with key {model_key}")
state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))
logger.info(f"model initialized with {parameter_count(state.model)} parameters")

# TODO: fix this
tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
# TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)
Expand Down Expand Up @@ -168,6 +164,8 @@ def main(config: TrainLmConfig):
)
trainer.add_hook(epoch_checkpointer, every=1)

state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

seek_dataloader = True
if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None:
state = load_checkpoint(state, config.initialize_from_checkpoint_path)
Expand Down

0 comments on commit 69ea6b1

Please sign in to comment.