diff --git a/src/molearn/trainers/openmm_physics_trainer.py b/src/molearn/trainers/openmm_physics_trainer.py
index 61917e2..37b1979 100644
--- a/src/molearn/trainers/openmm_physics_trainer.py
+++ b/src/molearn/trainers/openmm_physics_trainer.py
@@ -1,6 +1,29 @@
import torch
from molearn.loss_functions import openmm_energy
from .trainer import Trainer
+import os
+
+
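+# OpenMM force-field XML defining softened nonbonded forces; when soft_NB is
+# enabled this is written to a temporary file and loaded alongside amber14-all.xml.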
+soft_xml_script='''\
+
+
+
+'''
+
class OpenMM_Physics_Trainer(Trainer):
@@ -12,8 +35,8 @@ class OpenMM_Physics_Trainer(Trainer):
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
-
- def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, **kwargs):
+
+    def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file=None, soft_NB=True, **kwargs):
'''
Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy `
Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data`
@@ -25,13 +48,24 @@ def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp
:param \*\*kwargs: All aditional kwargs will be passed to :func:`openmm_energy `
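+        :param list xml_file: (default: None) force-field XML files passed to :func:`openmm_energy`; if None, ``amber14-all.xml`` is used, together with a temporary soft nonbonded XML when ``soft_NB`` is True
+        :param bool soft_NB: (default: True) replace the standard nonbonded forces with the softened versions defined in ``soft_xml_script``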
'''
+        tmp_filename = None
+        if xml_file is None and soft_NB:
+            print('using soft nonbonded forces by default')
+            from molearn.utils import random_string
+            # write the soft nonbonded force-field definitions to a uniquely named temporary file
+            tmp_filename = f'soft_nonbonded_{random_string()}.xml'
+            with open(tmp_filename, 'w') as f:
+                f.write(soft_xml_script)
+            xml_file = ['amber14-all.xml', tmp_filename]
+            kwargs['remove_NB'] = True
+        elif xml_file is None:
+            xml_file = ['amber14-all.xml']
self.start_physics_at = start_physics_at
self.psf = physics_scaling_factor
if clamp:
clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
else:
clamp_kwargs = None
- self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, **kwargs)
+        self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file=xml_file, **kwargs)
+        if tmp_filename is not None:
+            os.remove(tmp_filename)
def common_physics_step(self, batch, latent):
'''
diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py
index 950dd6c..866e6a1 100644
--- a/src/molearn/trainers/trainer.py
+++ b/src/molearn/trainers/trainer.py
@@ -138,7 +138,18 @@ def scheduler_step(self, logs):
'''
pass
- def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
+ def prepare_logs(self, log_filename, log_folder=None):
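+        '''
+        Set ``self.log_filename`` from ``log_filename`` and, if given, ``log_folder`` (created if it does not exist). When this is a repeated run (``self._repeat > 0``), the repeat index is prepended to the filename.
+
+        :param str log_filename: name of the log file
+        :param str log_folder: (default: None) directory in which the log file is placed
+        '''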
+ self.log_filename = log_filename
+ if log_folder is not None:
+ if not os.path.exists(log_folder):
+ os.mkdir(log_folder)
+        if hasattr(self, "_repeat") and self._repeat > 0:
+ self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}'
+ else:
+ self.log_filename = f'{log_folder}/{self.log_filename}'
+
+
+ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False):
'''
Calls the following in a loop:
@@ -158,12 +169,14 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
:param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename
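+        :param bool allow_grad_in_valid: (default: False) if True, run the validation epoch without wrapping it in ``torch.no_grad()``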
'''
- if log_filename is not None:
- self.log_filename = log_filename
- if log_folder is not None:
- if not os.path.exists(log_folder):
- os.mkdir(log_folder)
- self.log_filename = log_folder+'/'+self.log_filename
+ self.get_repeat(checkpoint_folder)
+ self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder)
if verbose is not None:
self.verbose = verbose
@@ -173,8 +186,11 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
time1 = time.time()
logs = self.train_epoch(epoch)
time2 = time.time()
- with torch.no_grad():
+ if allow_grad_in_valid:
logs.update(self.valid_epoch(epoch))
+ else:
+ with torch.no_grad():
+ logs.update(self.valid_epoch(epoch))
time3 = time.time()
self.scheduler_step(logs)
if self.best is None or self.best > logs['valid_loss']:
@@ -381,11 +397,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
'atoms': self._data.atoms,
'std': self.std,
'mean': self.mean},
- f'{checkpoint_folder}/last.ckpt')
+ f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt')
if self.best is None or self.best > valid_loss:
- filename = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
- shutil.copyfile(f'{checkpoint_folder}/last.ckpt', filename)
+            filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat > 0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt'
+            shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt', filename)
if self.best is not None:
os.remove(self.best_name)
self.best_name = filename
@@ -424,6 +440,20 @@ def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_opt
epoch = checkpoint['epoch']
self.epoch = epoch+1
+ def get_repeat(self, checkpoint_folder):
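+        '''
+        Determine the repeat index ``self._repeat`` for this run by finding the first ``last[_n].ckpt`` that does not already exist in ``checkpoint_folder`` (created if missing), so that repeated runs do not overwrite previous checkpoints and logs.
+        '''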
+ if not os.path.exists(checkpoint_folder):
+ os.mkdir(checkpoint_folder)
+ if not hasattr(self, '_repeat'):
+ self._repeat = 0
+ for i in range(1000):
+            if not os.path.exists(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt'):
+                break
+ else:
+ self._repeat += 1
+ else:
+            raise Exception("Something went wrong: surely you haven't run 1000 repeats?")
+
+
if __name__=='__main__':
pass
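
A minimal usage sketch of the updated API, assuming a trainer that already has an autoencoder, data, and optimiser configured; the constructor argument and the commented setup calls are assumptions, not part of this patch:

    import torch
    from molearn.trainers.openmm_physics_trainer import OpenMM_Physics_Trainer

    trainer = OpenMM_Physics_Trainer(device=torch.device('cuda'))  # device kwarg assumed
    # ... trainer.set_data(...), set_autoencoder(...), prepare_optimiser() as usual ...

    # Default: soft nonbonded forces from a temporary XML alongside amber14-all.xml.
    trainer.prepare_physics()
    # Alternatively: trainer.prepare_physics(soft_NB=False)
    # or:            trainer.prepare_physics(xml_file=['amber14-all.xml'])

    # Each call to run() picks a fresh repeat index via get_repeat(), so checkpoints
    # and logs from earlier runs in the same folders are not overwritten.
    trainer.run(max_epochs=100, log_filename='log.dat', log_folder='logs',
                checkpoint_folder='checkpoint_folder', allow_grad_in_valid=False)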