diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py index 3530541..866e6a1 100644 --- a/src/molearn/trainers/trainer.py +++ b/src/molearn/trainers/trainer.py @@ -143,7 +143,10 @@ def prepare_logs(self, log_filename, log_folder=None): if log_folder is not None: if not os.path.exists(log_folder): os.mkdir(log_folder) - self.log_filename = log_folder+'/'+self.log_filename + if hasattr(self, "_repeat") and self._repeat >0: + self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}' + else: + self.log_filename = f'{log_folder}/{self.log_filename}' def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False): @@ -166,6 +169,7 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre :param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename ''' + self.get_repeat(checkpoint_folder) self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder) #if log_filename is not None: # self.log_filename = log_filename @@ -382,7 +386,6 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss' :param str checkpoint_folder: The folder in which to save the checkpoint. :param str loss_key: (default: 'valid_loss') The key with which to get loss from valid_logs. ''' - self.get_repeat(checkpoint_folder) valid_loss = valid_logs[loss_key] if not os.path.exists(checkpoint_folder): os.mkdir(checkpoint_folder) @@ -394,11 +397,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss' 'atoms': self._data.atoms, 'std': self.std, 'mean': self.mean}, - f'{checkpoint_folder}/last{self._repeat}.ckpt') + f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt') if self.best is None or self.best > valid_loss: - filename = f'{checkpoint_folder}/checkpoint{self._repeat}_epoch{epoch}_loss{valid_loss}.ckpt' - shutil.copyfile(f'{checkpoint_folder}/last{self._repeat}.ckpt', filename) + filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt' + shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', filename) if self.best is not None: os.remove(self.best_name) self.best_name = filename @@ -441,14 +444,12 @@ def get_repeat(self, checkpoint_folder): if not os.path.exists(checkpoint_folder): os.mkdir(checkpoint_folder) if not hasattr(self, '_repeat'): - _repeat = 0 - self._repeat = f'_{_repeat}' if _repeat>0 else '' + self._repeat = 0 for i in range(1000): - if not os.path.exists(checkpoint_folder+f'/last{self._repeat}.ckpt'): + if not os.path.exists(checkpoint_folder+f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'): break#os.mkdir(checkpoint_folder) else: - _repeat += 1 - self._repeat = f'_{_repeat}' if _repeat>0 else '' + self._repeat += 1 else: raise Exception('Something went wrong, you surely havnt done 1000 repeats?')