diff --git a/src/molearn/trainers/openmm_physics_trainer.py b/src/molearn/trainers/openmm_physics_trainer.py
index 61917e2..37b1979 100644
--- a/src/molearn/trainers/openmm_physics_trainer.py
+++ b/src/molearn/trainers/openmm_physics_trainer.py
@@ -1,6 +1,29 @@
 import torch
 from molearn.loss_functions import openmm_energy
 from .trainer import Trainer
+import os
+
+
+soft_xml_script='''\
+
+
+
+'''
+
 
 
 class OpenMM_Physics_Trainer(Trainer):
@@ -12,8 +35,8 @@ class OpenMM_Physics_Trainer(Trainer):
     '''
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
-    def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, **kwargs):
+
+    def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file = None, soft_NB = True, **kwargs):
         '''
         Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy <molearn.loss_functions.openmm_energy>`
         Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data`
@@ -25,13 +48,24 @@ def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp
 
         :param \*\*kwargs: All aditional kwargs will be passed to :func:`openmm_energy <molearn.loss_functions.openmm_energy>`
         '''
+        if xml_file is None and soft_NB==True:
+            print('using soft nonbonded forces by default')
+            from molearn.utils import random_string
+            tmp_filename = f'soft_nonbonded_{random_string()}.xml'
+            with open(tmp_filename, 'w') as f:
+                f.write(soft_xml_script)
+            xml_file = ['amber14-all.xml', tmp_filename]
+            kwargs['remove_NB'] = True
+        elif xml_file is None:
+            xml_file = ['amber14-all.xml']
         self.start_physics_at = start_physics_at
         self.psf = physics_scaling_factor
         if clamp:
             clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
         else:
             clamp_kwargs = None
-        self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, **kwargs)
+        self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file = xml_file, **kwargs)
+        os.remove(tmp_filename)
 
     def common_physics_step(self, batch, latent):
         '''
diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py
index 950dd6c..866e6a1 100644
--- a/src/molearn/trainers/trainer.py
+++ b/src/molearn/trainers/trainer.py
@@ -138,7 +138,18 @@ def scheduler_step(self, logs):
         '''
         pass
 
-    def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
+    def prepare_logs(self, log_filename, log_folder=None):
+        self.log_filename = log_filename
+        if log_folder is not None:
+            if not os.path.exists(log_folder):
+                os.mkdir(log_folder)
+            if hasattr(self, "_repeat") and self._repeat >0:
+                self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}'
+            else:
+                self.log_filename = f'{log_folder}/{self.log_filename}'
+
+
+    def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False):
         '''
         Calls the following in a loop:
 
@@ -158,12 +169,14 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
         :param bool verbose: (default: None) set trainer.verbose.
             If True, the epoch logs will be printed as well as written to log_filename
 
         '''
-        if log_filename is not None:
-            self.log_filename = log_filename
-            if log_folder is not None:
-                if not os.path.exists(log_folder):
-                    os.mkdir(log_folder)
-                self.log_filename = log_folder+'/'+self.log_filename
+        self.get_repeat(checkpoint_folder)
+        self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder)
+        #if log_filename is not None:
+        #    self.log_filename = log_filename
+        #    if log_folder is not None:
+        #        if not os.path.exists(log_folder):
+        #            os.mkdir(log_folder)
+        #        self.log_filename = log_folder+'/'+self.log_filename
         if verbose is not None:
             self.verbose = verbose
@@ -173,8 +186,11 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
                 time1 = time.time()
                 logs = self.train_epoch(epoch)
                 time2 = time.time()
-                with torch.no_grad():
+                if allow_grad_in_valid:
                     logs.update(self.valid_epoch(epoch))
+                else:
+                    with torch.no_grad():
+                        logs.update(self.valid_epoch(epoch))
                 time3 = time.time()
                 self.scheduler_step(logs)
                 if self.best is None or self.best > logs['valid_loss']:
@@ -381,11 +397,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
                     'atoms': self._data.atoms,
                     'std': self.std,
                     'mean': self.mean},
-                   f'{checkpoint_folder}/last.ckpt')
+                   f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt')
 
         if self.best is None or self.best > valid_loss:
-            filename = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
-            shutil.copyfile(f'{checkpoint_folder}/last.ckpt', filename)
+            filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt'
+            shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', filename)
             if self.best is not None:
                 os.remove(self.best_name)
             self.best_name = filename
@@ -424,6 +440,20 @@ def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_opt
         epoch = checkpoint['epoch']
         self.epoch = epoch+1
 
+    def get_repeat(self, checkpoint_folder):
+        if not os.path.exists(checkpoint_folder):
+            os.mkdir(checkpoint_folder)
+        if not hasattr(self, '_repeat'):
+            self._repeat = 0
+            for i in range(1000):
+                if not os.path.exists(checkpoint_folder+f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'):
+                    break#os.mkdir(checkpoint_folder)
+                else:
+                    self._repeat += 1
+            else:
+                raise Exception('Something went wrong, you surely havnt done 1000 repeats?')
+
+
 
 if __name__=='__main__':
     pass
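Taken together, the openmm_physics_trainer.py changes mean that prepare_physics() now defaults to soft nonbonded forces: it writes a temporary XML force-field file, loads it alongside 'amber14-all.xml', sets remove_NB=True for openmm_energy, and deletes the temporary file once the loss object is built, while an explicit xml_file argument (or soft_NB=False) bypasses all of this. A minimal usage sketch follows; everything outside the prepare_physics() calls (PDBData, import_pdb, prepare_dataset, set_data, the input filename) is assumed from molearn's example scripts rather than taken from this diff.

# Hedged usage sketch: only the prepare_physics() arguments come from this diff;
# the surrounding setup mirrors molearn's example scripts and is an assumption.
import torch
from molearn.data import PDBData
from molearn.trainers import OpenMM_Physics_Trainer

data = PDBData()
data.import_pdb('protein_trajectory.pdb')   # assumed input file
data.prepare_dataset()

trainer = OpenMM_Physics_Trainer(device=torch.device('cpu'))
trainer.set_data(data, batch_size=8, validation_split=0.1)

# Default path (soft_NB=True): a temporary soft-nonbonded XML is written,
# combined with 'amber14-all.xml', remove_NB=True is passed through to
# openmm_energy, and the temporary file is removed afterwards.
trainer.prepare_physics()

# Explicit alternative: supply the force-field XML list yourself and skip the
# soft nonbonded forces.
# trainer.prepare_physics(xml_file=['amber14-all.xml'], soft_NB=False)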
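The trainer.py changes make repeated runs against the same checkpoint_folder self-numbering: get_repeat() claims the first free slot ('last.ckpt' for repeat 0, then 'last_1.ckpt', 'last_2.ckpt', ...), checkpoint() suffixes its files with that index, and prepare_logs() prefixes the log file with it. The helper below is a standalone illustration of that naming scheme; the function name next_repeat_index and its max_repeats parameter are invented for the example and are not part of molearn's API.

import os

def next_repeat_index(checkpoint_folder, max_repeats=1000):
    # Illustrative stand-in for Trainer.get_repeat: repeat 0 owns 'last.ckpt',
    # repeat N > 0 owns 'last_N.ckpt'; the first missing file wins.
    os.makedirs(checkpoint_folder, exist_ok=True)
    for repeat in range(max_repeats):
        suffix = f'_{repeat}' if repeat > 0 else ''
        if not os.path.exists(os.path.join(checkpoint_folder, f'last{suffix}.ckpt')):
            return repeat
    raise RuntimeError(f'more than {max_repeats} repeats found in {checkpoint_folder}')

# Example: with 'last.ckpt' and 'last_1.ckpt' already present, the next run is
# repeat 2, so it saves 'last_2.ckpt', names its best checkpoints
# 'checkpoint_2_epoch{epoch}_loss{loss}.ckpt', and prepare_logs() writes to
# '{log_folder}/2_{log_filename}'.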