Skip to content

Commit

Permalink
Avoid overwriting checkpoints and log files
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Musson committed Oct 26, 2023
1 parent 9f77650 commit c4fb49f
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
:param str checkpoint_folder: The folder in which to save the checkpoint.
:param str loss_key: (default: 'valid_loss') The key with which to get loss from valid_logs.
'''
self.get_repeat(checkpoint_folder)
valid_loss = valid_logs[loss_key]
if not os.path.exists(checkpoint_folder):
os.mkdir(checkpoint_folder)
Expand All @@ -393,11 +394,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
'atoms': self._data.atoms,
'std': self.std,
'mean': self.mean},
f'{checkpoint_folder}/last.ckpt')
f'{checkpoint_folder}/last{self._repeat}.ckpt')

if self.best is None or self.best > valid_loss:
filename = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last.ckpt', filename)
filename = f'{checkpoint_folder}/checkpoint{self._repeat}_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last{self._repeat}.ckpt', filename)
if self.best is not None:
os.remove(self.best_name)
self.best_name = filename
Expand Down Expand Up @@ -436,6 +437,22 @@ def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_opt
epoch = checkpoint['epoch']
self.epoch = epoch+1

def get_repeat(self, checkpoint_folder):
    '''
    Pick the filename suffix that stops this trainer from overwriting checkpoints already present in ``checkpoint_folder``.

    The suffix is stored in ``self._repeat``. It is an empty string when the folder holds no ``last.ckpt``,
    otherwise ``_N`` for the first N in 1..999 for which ``last_N.ckpt`` does not exist.
    The search runs only once per instance: if ``self._repeat`` is already set it is left untouched.

    :param str checkpoint_folder: The folder in which checkpoints are saved. It is created, together with any missing parent folders, if it does not exist.
    :raises RuntimeError: If all 1000 candidate suffixes are already taken. ``self._repeat`` is not set in that case.
    '''
    # exist_ok avoids the check-then-create race and makedirs also builds missing parents.
    os.makedirs(checkpoint_folder, exist_ok=True)

    # The suffix must stay fixed for the lifetime of the trainer, otherwise every
    # call after the first save would see its own last*.ckpt and move to a new suffix.
    if hasattr(self, '_repeat'):
        return

    for repeat in range(1000):
        suffix = f'_{repeat}' if repeat > 0 else ''
        if not os.path.exists(os.path.join(checkpoint_folder, f'last{suffix}.ckpt')):
            self._repeat = suffix
            return

    raise RuntimeError('Something went wrong, you surely havnt done 1000 repeats?')



# Script entry guard: running this module directly does nothing.
if __name__=='__main__':
pass

0 comments on commit c4fb49f

Please sign in to comment.