Scaling #28

Merged 2 commits on Dec 16, 2024
112 changes: 70 additions & 42 deletions src/molearn/trainers/openmm_physics_trainer.py
@@ -4,7 +4,7 @@
import os


soft_xml_script='''\
soft_xml_script = """\
<ForceField>
<Script>
import openmm as mm
@@ -22,95 +22,121 @@
nb.addExclusion(a1, a2)
</Script>
</ForceField>
'''

"""


class OpenMM_Physics_Trainer(Trainer):
'''
"""
OpenMM_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step.
An extra 'physics_loss' is calculated using OpenMM and its forces are inserted into the backward pass.
Using this trainer requires the additional step of calling :func:`prepare_physics <molearn.trainers.OpenMM_Physics_Trainer.prepare_physics>`.

'''
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file = None, soft_NB = True, **kwargs):
'''

def prepare_physics(
self,
physics_scaling_factor=0.1,
clamp_threshold=1e8,
clamp=False,
start_physics_at=10,
xml_file=None,
soft_NB=True,
**kwargs,
):
"""
Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy <molearn.loss_functions.openmm_energy>`
Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data<molearn.trainer.Trainer.set_data>`

:param float physics_scaling_factor: scaling factor saved to ``self.psf`` that is used in :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`. Defaults to 0.1
:param float clamp_threshold: if ``clamp=True`` is passed then forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8
:param bool clamp: Whether to clamp the forces. Defaults to False
:param int start_physics_at: As of yet unused parameter saved as ``self.start_physics_at = start_physics_at``. Default: 0
:param int start_physics_at: At which epoch the physics loss will be added to the loss. Default: 10
:param \*\*kwargs: All additional kwargs will be passed to :func:`openmm_energy <molearn.loss_functions.openmm_energy>`

'''
if xml_file is None and soft_NB==True:
print('using soft nonbonded forces by default')
"""
if xml_file is None and soft_NB:
print("using soft nonbonded forces by default")
from molearn.utils import random_string
tmp_filename = f'soft_nonbonded_{random_string()}.xml'
with open(tmp_filename, 'w') as f:

tmp_filename = f"soft_nonbonded_{random_string()}.xml"
with open(tmp_filename, "w") as f:
f.write(soft_xml_script)
xml_file = ['amber14-all.xml', tmp_filename]
kwargs['remove_NB'] = True
xml_file = ["amber14-all.xml", tmp_filename]
kwargs["remove_NB"] = True
elif xml_file is None:
xml_file = ['amber14-all.xml']
xml_file = ["amber14-all.xml"]
self.start_physics_at = start_physics_at
self.psf = physics_scaling_factor
if clamp:
clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
else:
clamp_kwargs = None
self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file = xml_file, **kwargs)
self.physics_loss = openmm_energy(
self.mol,
self.std,
clamp=clamp_kwargs,
platform="CUDA" if self.device == torch.device("cuda") else "Reference",
atoms=self._data.atoms,
xml_file=xml_file,
**kwargs,
)
os.remove(tmp_filename)
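
For context, a minimal usage sketch of the updated signature, assuming a trainer whose data has already been set with ``set_data`` (the variable name ``trainer`` and the parameter values are illustrative, not prescriptive):

```python
# Hypothetical usage sketch -- assumes `trainer` is an OpenMM_Physics_Trainer
# whose data has already been set via trainer.set_data(...).
trainer.prepare_physics(
    physics_scaling_factor=0.1,  # stored as trainer.psf
    start_physics_at=10,         # physics term joins the training loss from epoch 10 on
    soft_NB=True,                # writes the soft non-bonded XML and sets remove_NB=True
)
```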

def common_physics_step(self, batch, latent):
'''
"""
Called from both :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>` and :func:`valid_step <molearn.trainers.OpenMM_Physics_Trainer.valid_step>`.
Takes random interpolations between the latent vectors of adjacent samples. These are decoded (the decoded structures are saved as ``self._internal['generated']`` if needed elsewhere) and the energy terms are calculated with ``self.physics_loss``.

:param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures and is used only to determine ``n_atoms``.
:param torch.Tensor latent: tensor of shape [batch_size, 2, 1]. The encoded vectors of the mini-batch.
'''
alpha = torch.rand(int(len(batch)//2), 1, 1).type_as(latent)
latent_interpolated = (1-alpha)*latent[:-1:2] + alpha*latent[1::2]
"""
alpha = torch.rand(int(len(batch) // 2), 1, 1).type_as(latent)
latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]

generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)]
self._internal['generated'] = generated
generated = self.autoencoder.decode(latent_interpolated)[:, :, : batch.size(2)]
self._internal["generated"] = generated
energy = self.physics_loss(generated)
energy[energy.isinf()] = 1e35
energy = torch.clamp(energy, max=1e34)
energy = energy.nanmean()

return {'physics_loss':energy} # a if not energy.isinf() else torch.tensor(0.0)}
return {
"physics_loss": energy
} # a if not energy.isinf() else torch.tensor(0.0)}
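
As a side note, a standalone sketch of the pairwise interpolation indexing used above (shapes and values are illustrative only; this is not molearn code):

```python
import torch

batch_size = 6
latent = torch.arange(batch_size, dtype=torch.float32).reshape(batch_size, 1, 1)

alpha = torch.rand(batch_size // 2, 1, 1)   # one random weight per pair
even = latent[:-1:2]                        # samples 0, 2, 4
odd = latent[1::2]                          # samples 1, 3, 5
interpolated = (1 - alpha) * even + alpha * odd
print(interpolated.shape)                   # torch.Size([3, 1, 1])
```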

def train_step(self, batch):
'''
"""
This method overrides :func:`Trainer.train_step <molearn.trainers.Trainer.train_step>` and adds an additional 'Physics_loss' term.

Before epoch ``self.start_physics_at`` the loss is the plain mse_loss. At epoch ``self.start_physics_at`` a scaling factor ``self.phy_scale`` is computed once with :func:`Trainer._get_scale <molearn.trainers.Trainer._get_scale>`, bringing the physics term to the same order of magnitude as the mse_loss and further scaling it by ``self.psf``.
From that epoch onwards the losses are summed as ``mse_loss + self.phy_scale*physics_loss``.

Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.

:param torch.Tensor batch: tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by ``self.std``.
:returns: Returns the losses. The dictionary must contain an entry with key ``'loss'`` on which :func:`self.train_epoch <molearn.trainers.Trainer.train_epoch>` will call ``result['loss'].backward()`` to obtain gradients.
:rtype: dict
'''
"""

results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal['encoded']))
results.update(self.common_physics_step(batch, self._internal["encoded"]))

with torch.no_grad():
scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
final_loss = results['mse_loss']+scale*results['physics_loss']
results['loss'] = final_loss
if self.epoch == self.start_physics_at:
self.phy_scale = self._get_scale(
results["mse_loss"],
results["physics_loss"],
self.psf,
)
if self.epoch >= self.start_physics_at:
final_loss = results["mse_loss"] + self.phy_scale * results["physics_loss"]
else:
final_loss = results["mse_loss"]

results["loss"] = final_loss
return results
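
In other words (the values below are made up for illustration; ``_get_scale`` is shown in the trainer.py hunk further down), the physics term is ignored before ``start_physics_at`` and afterwards enters with a frozen scale that keeps it roughly ``psf`` times the magnitude of the MSE term:

```python
# Illustrative numbers only, not real training output.
mse_loss, physics_loss = 3e-3, 2e5
psf, start_physics_at = 0.1, 10

phy_scale = 10 ** (-3 - 5) * psf                    # what _get_scale(3e-3, 2e5, 0.1) returns: 1e-9

loss_before = mse_loss                              # epochs 0..9
loss_after = mse_loss + phy_scale * physics_loss    # epochs >= 10: 3e-3 + 2e-4
```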

def valid_step(self, batch):
'''
"""
This method overrides :func:`Trainer.valid_step <molearn.trainers.Trainer.valid_step>` and adds an additional 'Physics_loss' term.

Unlike :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`, this method sums the logs of mse_loss and physics_loss: ``final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss'])``
@@ -121,15 +147,17 @@ def valid_step(self, batch):
:returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that will be the score via which the best checkpoint is determined.
:rtype: dict

'''
"""

results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal['encoded']))
results.update(self.common_physics_step(batch, self._internal["encoded"]))
# scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss'])
results['loss'] = final_loss
final_loss = torch.log(results["mse_loss"]) + self.psf * torch.log(
results["physics_loss"]
)
results["loss"] = final_loss
return results
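
A toy check of the validation score above (numbers are made up): summing the logs keeps the two terms comparable even though their raw magnitudes differ by many orders:

```python
import torch

mse_loss = torch.tensor(3e-3)
physics_loss = torch.tensor(2e5)
psf = 0.1

score = torch.log(mse_loss) + psf * torch.log(physics_loss)
print(score.item())   # log(3e-3) + 0.1 * log(2e5)  ->  about -5.81 + 1.22 = -4.59
```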


if __name__=='__main__':
if __name__ == "__main__":
pass
16 changes: 16 additions & 0 deletions src/molearn/trainers/trainer.py
@@ -1,6 +1,7 @@
import os
import glob
import shutil
import math
import numpy as np
import time
import torch
@@ -541,6 +542,21 @@ def get_repeat(self, checkpoint_folder):
"Something went wrong, you surely havnt done 1000 repeats?"
)

def _get_scale(
self, cur_mse_loss: float, loss_to_scale: float, scale_scale: float = 1.0
):
"""
Get a scaling factor that brings a loss into the same order of magnitude as `cur_mse_loss`.

:param float cur_mse_loss: the mse loss of the current epoch
:param float loss_to_scale: the loss that should be scaled to the same order of magnitude as `cur_mse_loss`
:param float scale_scale: factor to further increase or decrease the computed scale
:return float scaling_factor: the calculated scaling factor for `loss_to_scale`
"""
mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
return 10 ** (mag_mse - mag_phy) * scale_scale
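
A quick standalone check of the new helper (a copy of the logic above so it runs without a Trainer instance; the example values are arbitrary):

```python
import math

def get_scale(cur_mse_loss, loss_to_scale, scale_scale=1.0):
    # mirrors Trainer._get_scale above
    mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
    mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
    return 10 ** (mag_mse - mag_phy) * scale_scale

scale = get_scale(4.7e-4, 8.1e3)   # magnitudes -4 and 3  ->  scale = 1e-7
print(scale)                       # 1e-07
print(scale * 8.1e3)               # ~8.1e-4, same order of magnitude as 4.7e-4
```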


if __name__ == "__main__":
pass