Skip to content

Commit

Permalink
Merge pull request #28 from Degiacomi-Lab/scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
degiacom authored Dec 16, 2024
2 parents d4700a4 + 8eaa644 commit 8703db0
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 42 deletions.
112 changes: 70 additions & 42 deletions src/molearn/trainers/openmm_physics_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os


soft_xml_script='''\
soft_xml_script = """\
<ForceField>
<Script>
import openmm as mm
Expand All @@ -22,95 +22,121 @@
nb.addExclusion(a1, a2)
</Script>
</ForceField>
'''

"""


class OpenMM_Physics_Trainer(Trainer):
'''
"""
OpenMM_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step.
An extra 'physics_loss' is calculated using OpenMM and the forces are inserted into backwards pass.
To use this trainer requires the additional step of calling :func:`prepare_physics <molearn.trainers.OpenMM_Physics_Trainer.prepare_physics>`.
'''
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file = None, soft_NB = True, **kwargs):
'''

def prepare_physics(
self,
physics_scaling_factor=0.1,
clamp_threshold=1e8,
clamp=False,
start_physics_at=10,
xml_file=None,
soft_NB=True,
**kwargs,
):
"""
Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy <molearn.loss_functions.openmm_energy>`
Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data<molearn.trainer.Trainer.set_data>`
:param float physics_scaling_factor: scaling factor saved to ``self.psf`` that is used in :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`. Defaults to 0.1
:param float clamp_threshold: if ``clamp=True`` is passed then forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e-8
:param bool clamp: Whether to clamp the forces. Defaults to False
:param int start_physics_at: As of yet unused parameter saved as ``self.start_physics_at = start_physics_at``. Default: 0
:param int start_physics_at: At which epoch the physics loss will be added to the loss. Default: 10
:param \*\*kwargs: All aditional kwargs will be passed to :func:`openmm_energy <molearn.loss_functions.openmm_energy>`
'''
if xml_file is None and soft_NB==True:
print('using soft nonbonded forces by default')
"""
if xml_file is None and soft_NB:
print("using soft nonbonded forces by default")
from molearn.utils import random_string
tmp_filename = f'soft_nonbonded_{random_string()}.xml'
with open(tmp_filename, 'w') as f:

tmp_filename = f"soft_nonbonded_{random_string()}.xml"
with open(tmp_filename, "w") as f:
f.write(soft_xml_script)
xml_file = ['amber14-all.xml', tmp_filename]
kwargs['remove_NB'] = True
xml_file = ["amber14-all.xml", tmp_filename]
kwargs["remove_NB"] = True
elif xml_file is None:
xml_file = ['amber14-all.xml']
xml_file = ["amber14-all.xml"]
self.start_physics_at = start_physics_at
self.psf = physics_scaling_factor
if clamp:
clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
else:
clamp_kwargs = None
self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file = xml_file, **kwargs)
self.physics_loss = openmm_energy(
self.mol,
self.std,
clamp=clamp_kwargs,
platform="CUDA" if self.device == torch.device("cuda") else "Reference",
atoms=self._data.atoms,
xml_file=xml_file,
**kwargs,
)
os.remove(tmp_filename)

def common_physics_step(self, batch, latent):
'''
"""
Called from both :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>` and :func:`valid_step <molearn.trainers.OpenMM_Physics_Trainer.valid_step>`.
Takes random interpolations between adjacent samples latent vectors. These are decoded (decoded structures saved as ``self._internal['generated'] = generated if needed elsewhere) and the energy terms calculated with ``self.physics_loss``.
:param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms]. Give access to the mini-batch of structures. This is used to determine ``n_atoms``
:param torch.Tensor latent: tensor shape [batch_size, 2, 1]. Pass the encoded vectors of the mini-batch.
'''
alpha = torch.rand(int(len(batch)//2), 1, 1).type_as(latent)
latent_interpolated = (1-alpha)*latent[:-1:2] + alpha*latent[1::2]
"""
alpha = torch.rand(int(len(batch) // 2), 1, 1).type_as(latent)
latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]

generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)]
self._internal['generated'] = generated
generated = self.autoencoder.decode(latent_interpolated)[:, :, : batch.size(2)]
self._internal["generated"] = generated
energy = self.physics_loss(generated)
energy[energy.isinf()] = 1e35
energy = torch.clamp(energy, max=1e34)
energy = energy.nanmean()

return {'physics_loss':energy} # a if not energy.isinf() else torch.tensor(0.0)}
return {
"physics_loss": energy
} # a if not energy.isinf() else torch.tensor(0.0)}

def train_step(self, batch):
'''
"""
This method overrides :func:`Trainer.train_step <molearn.trainers.Trainer.train_step>` and adds an additional 'Physics_loss' term.
Mse_loss and physics loss are summed (``Mse_loss + scale*physics_loss``)with a scaling factor ``self.psf*mse_loss/Physics_loss``. Mathematically this cancels out the physics_loss and the final loss is (1+self.psf)*mse_loss. However because the scaling factor is calculated within a ``torch.no_grad`` context manager the gradients are not computed.
This is essentially the same as scaling the physics_loss with any arbitary scaling factor but in this case simply happens to be exactly proportional to the ration of Mse_loss and physics_loss in every step.
Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.
:param torch.Tensor batch: tensor shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``.
:returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that :func:`self.train_epoch <molearn.trainers.Trainer.train_epoch>` will call ``result['loss'].backwards()`` to obtain gradients.
:rtype: dict
'''
"""

results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal['encoded']))
results.update(self.common_physics_step(batch, self._internal["encoded"]))

with torch.no_grad():
scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
final_loss = results['mse_loss']+scale*results['physics_loss']
results['loss'] = final_loss
if self.epoch == self.start_physics_at:
self.phy_scale = self._get_scale(
results["mse_loss"],
results["physics_loss"],
self.psf,
)
if self.epoch >= self.start_physics_at:
final_loss = results["mse_loss"] + self.phy_scale * results["physics_loss"]
else:
final_loss = results["mse_loss"]

results["loss"] = final_loss
return results

def valid_step(self, batch):
'''
"""
This method overrides :func:`Trainer.valid_step <molearn.trainers.Trainer.valid_step>` and adds an additional 'Physics_loss' term.
Differently to :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>` this method sums the logs of mse_loss and physics_loss ``final_loss = torch.log(results['mse_loss'])+scale*torch.log(results['physics_loss'])``
Expand All @@ -121,15 +147,17 @@ def valid_step(self, batch):
:returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that will be the score via which the best checkpoint is determined.
:rtype: dict
'''
"""

results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal['encoded']))
results.update(self.common_physics_step(batch, self._internal["encoded"]))
# scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss'])
results['loss'] = final_loss
final_loss = torch.log(results["mse_loss"]) + self.psf * torch.log(
results["physics_loss"]
)
results["loss"] = final_loss
return results


if __name__=='__main__':
if __name__ == "__main__":
pass
16 changes: 16 additions & 0 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import glob
import shutil
import math
import numpy as np
import time
import torch
Expand Down Expand Up @@ -541,6 +542,21 @@ def get_repeat(self, checkpoint_folder):
"Something went wrong, you surely havnt done 1000 repeats?"
)

def _get_scale(
self, cur_mse_loss: float, loss_to_scale: float, scale_scale: float = 1.0
):
"""
get a scaling factor to scale a loss to be in the same order of magnitude like `cur_mse_loss`
:param float cur_mse_loss: the mse loss of the current epoch
:param float loss_to_scale: the loss that should be scaled to be in the same order of magnitude as the `cur_mse_loss`
:param float scale_scale: scale to in-/ decrease the scale further
:return float scaling_factor: the calculated scaling factor for the `loss_to_scale`
"""
mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
return 10 ** (mag_mse - mag_phy) * scale_scale


if __name__ == "__main__":
pass

0 comments on commit 8703db0

Please sign in to comment.