[Feature] Avoid overriding results and default soft nonbonded energy #13

Merged: 4 commits, Oct 27, 2023
src/molearn/trainers/openmm_physics_trainer.py: 37 additions & 3 deletions
@@ -1,6 +1,29 @@
import torch
from molearn.loss_functions import openmm_energy
from .trainer import Trainer
import os


soft_xml_script='''\
<ForceField>
<Script>
import openmm as mm
nb = mm.CustomNonbondedForce('C/((r/0.2)^4+1)')
nb.addGlobalParameter('C', 1.0)
sys.addForce(nb)
for i in range(sys.getNumParticles()):
nb.addParticle([])
exclusions = set()
for bond in data.bonds:
exclusions.add((min(bond.atom1, bond.atom2), max(bond.atom1, bond.atom2)))
for angle in data.angles:
exclusions.add((min(angle[0], angle[2]), max(angle[0], angle[2])))
for a1, a2 in exclusions:
nb.addExclusion(a1, a2)
</Script>
</ForceField>
'''
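As an aside, the soft potential defined above, C/((r/0.2)^4 + 1), stays finite at zero separation, which is presumably what lets it tolerate the badly clashing geometries a partially trained decoder can produce. A quick standalone check of its shape (NumPy, distances illustrative, not part of this PR):

import numpy as np

r = np.array([0.0, 0.1, 0.2, 0.5, 1.0])   # separations in nm (illustrative)
C = 1.0                                    # the global parameter from the XML above
print(C / ((r / 0.2)**4 + 1))              # ~[1.0, 0.941, 0.5, 0.025, 0.0016]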



class OpenMM_Physics_Trainer(Trainer):
@@ -12,8 +35,8 @@ class OpenMM_Physics_Trainer(Trainer):
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, **kwargs):
def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file = None, soft_NB = True, **kwargs):
'''
Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy <molearn.loss_functions.openmm_energy>`
Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data<molearn.trainer.Trainer.set_data>`
@@ -25,13 +48,24 @@ def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp
        :param \*\*kwargs: All additional kwargs will be passed to :func:`openmm_energy <molearn.loss_functions.openmm_energy>`

'''
        tmp_filename = None  # set only when the temporary soft nonbonded XML is written below
        if xml_file is None and soft_NB:
print('using soft nonbonded forces by default')
from molearn.utils import random_string
tmp_filename = f'soft_nonbonded_{random_string()}.xml'
with open(tmp_filename, 'w') as f:
f.write(soft_xml_script)
xml_file = ['amber14-all.xml', tmp_filename]
kwargs['remove_NB'] = True
elif xml_file is None:
xml_file = ['amber14-all.xml']
self.start_physics_at = start_physics_at
self.psf = physics_scaling_factor
if clamp:
clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
else:
clamp_kwargs = None
self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, **kwargs)
self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file = xml_file, **kwargs)
        if tmp_filename is not None:
            os.remove(tmp_filename)  # clean up the temporary soft nonbonded XML

def common_physics_step(self, batch, latent):
'''
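For orientation, a minimal usage sketch of the new prepare_physics arguments, assuming a trainer whose data has already been loaded with set_data (trainer construction is omitted and the variable name is illustrative):

# trainer: an OpenMM_Physics_Trainer with data already set via set_data()

# Default after this PR: a temporary soft nonbonded XML is written, passed
# together with amber14-all.xml, and remove_NB=True strips the standard
# nonbonded terms before the energy is evaluated.
trainer.prepare_physics(physics_scaling_factor=0.1)

# Opt out of the soft potential and keep plain amber14-all.xml forces.
trainer.prepare_physics(soft_NB=False)

# Supplying xml_file explicitly bypasses the soft default entirely.
trainer.prepare_physics(xml_file=['amber14-all.xml'])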
src/molearn/trainers/trainer.py: 41 additions & 11 deletions
@@ -138,7 +138,18 @@ def scheduler_step(self, logs):
'''
pass

def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None):
def prepare_logs(self, log_filename, log_folder=None):
self.log_filename = log_filename
if log_folder is not None:
if not os.path.exists(log_folder):
os.mkdir(log_folder)
if hasattr(self, "_repeat") and self._repeat >0:
self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}'
else:
self.log_filename = f'{log_folder}/{self.log_filename}'


def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False):
'''
Calls the following in a loop:

@@ -158,12 +169,14 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
:param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename

'''
if log_filename is not None:
self.log_filename = log_filename
if log_folder is not None:
if not os.path.exists(log_folder):
os.mkdir(log_folder)
self.log_filename = log_folder+'/'+self.log_filename
self.get_repeat(checkpoint_folder)
self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder)
if verbose is not None:
self.verbose = verbose

@@ -173,8 +186,11 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
time1 = time.time()
logs = self.train_epoch(epoch)
time2 = time.time()
with torch.no_grad():
if allow_grad_in_valid:
logs.update(self.valid_epoch(epoch))
else:
with torch.no_grad():
logs.update(self.valid_epoch(epoch))
time3 = time.time()
self.scheduler_step(logs)
if self.best is None or self.best > logs['valid_loss']:
@@ -381,11 +397,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
'atoms': self._data.atoms,
'std': self.std,
'mean': self.mean},
f'{checkpoint_folder}/last.ckpt')
f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt')

if self.best is None or self.best > valid_loss:
filename = f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last.ckpt', filename)
filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', filename)
if self.best is not None:
os.remove(self.best_name)
self.best_name = filename
@@ -424,6 +440,20 @@ def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_opt
epoch = checkpoint['epoch']
self.epoch = epoch+1

def get_repeat(self, checkpoint_folder):
if not os.path.exists(checkpoint_folder):
os.mkdir(checkpoint_folder)
if not hasattr(self, '_repeat'):
self._repeat = 0
for i in range(1000):
if not os.path.exists(checkpoint_folder+f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'):
                    break
else:
self._repeat += 1
else:
                raise Exception("Something went wrong, you surely haven't done 1000 repeats?")



if __name__=='__main__':
pass
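Finally, a short sketch of how the repeat mechanism above keeps a second run from overwriting the first; folder and file names are illustrative, and trainer stands for any fully prepared Trainer instance:

# Run 1: checkpoint_folder is empty, so get_repeat() leaves self._repeat at 0
# and results go to checkpoint_folder/last.ckpt and logs/run.log.
trainer.run(max_epochs=100, log_filename='run.log', log_folder='logs',
            checkpoint_folder='checkpoint_folder')

# Run 2 (a fresh trainer pointed at the same folder): last.ckpt already exists,
# so get_repeat() sets self._repeat = 1; checkpoints become last_1.ckpt and
# checkpoint_1_epoch<N>_loss<L>.ckpt, and prepare_logs() writes logs/1_run.log.
trainer.run(max_epochs=100, log_filename='run.log', log_folder='logs',
            checkpoint_folder='checkpoint_folder')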