diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index cf78b81f8..2ce2135fd 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -134,10 +134,6 @@ def main(config: TrainLmConfig): if vocab_size != Vocab.size: logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") - logger.info(f"initializing model with key {model_key}") - state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) - logger.info(f"model initialized with {parameter_count(state.model)} parameters") - # TODO: fix this tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size) # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key) @@ -168,6 +164,8 @@ def main(config: TrainLmConfig): ) trainer.add_hook(epoch_checkpointer, every=1) + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + seek_dataloader = True if int(state.step) == 0 and config.initialize_from_checkpoint_path is not None: state = load_checkpoint(state, config.initialize_from_checkpoint_path)