diff --git a/src/levanter/main/train_lm.py b/src/levanter/main/train_lm.py index 17de8f52d..9c511d31a 100644 --- a/src/levanter/main/train_lm.py +++ b/src/levanter/main/train_lm.py @@ -244,8 +244,6 @@ def compute_log_probs(model, example): ## OK, actually run training! trainer.train(state, train_loader) - - ## OK, actually run training! # checkpointer.on_step(last_step, force=True) diff --git a/src/levanter/trainer.py b/src/levanter/trainer.py index 1ebc22122..3973c025a 100644 --- a/src/levanter/trainer.py +++ b/src/levanter/trainer.py @@ -376,26 +376,12 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi while int(state.step) < self.num_train_steps: with capture_time() as loading_time: example = next(iter_data) - # while int(state.step) < target_steps and (epochs is None or current_epoch < epochs): - # current_epoch += 1 - # print(f"Starting epoch {current_epoch}") - # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) info = self.train_step(state, example) state = info.state if run_hooks: with capture_time() as hook_time: self.run_hooks(info) - # while True: - # try: - # with capture_time() as loading_time: - # example = next(iter_data) - # except StopIteration: - # # End of DataLoader iterator, proceed to next epoch - # train_loader = train_loader.iter_from_step(int(state.step)) - # print(f"End of epoch {current_epoch}") - # levanter.tracker.log_metrics({"epochs": current_epoch }, step=state.step) - # current_epoch += 1 levanter.tracker.log_metrics({"throughput/hook_time": hook_time()}, step=info.step)