diff --git a/tools/lightning_train_net.py b/tools/lightning_train_net.py index 87cfe84feb..1239bf2a84 100644 --- a/tools/lightning_train_net.py +++ b/tools/lightning_train_net.py @@ -56,6 +56,9 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None: self.start_iter = checkpointed_state["iteration"] + if self.storage is None: + self.storage = EventStorage(0) + self.storage.__enter__() self.storage.iter = self.start_iter def setup(self, stage: str): @@ -83,6 +86,7 @@ def training_step(self, batch, batch_idx): self.storage.__enter__() self.iteration_timer.trainer = weakref.proxy(self) self.iteration_timer.before_step() + if self.writers is None: self.writers = ( default_writers(self.cfg.OUTPUT_DIR, self.max_iter) if comm.is_main_process()