From c28e8cb7322bbc5f545ee6ade5df5fb3063842f5 Mon Sep 17 00:00:00 2001
From: gwirn <71886945+gwirn@users.noreply.github.com>
Date: Fri, 8 Nov 2024 13:01:15 +0100
Subject: [PATCH 1/2] only formatting
---
.../trainers/openmm_physics_trainer.py | 98 ++++++++++++-------
1 file changed, 60 insertions(+), 38 deletions(-)
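
For reference, a minimal sketch (not part of the diff below) of the loss combination that the train_step docstring describes, with dummy tensors standing in for the real losses. The scale is computed under torch.no_grad, so the physics term still contributes gradients while the ratio itself is treated as a constant, and the loss value works out to (1 + psf) * mse_loss:

import torch

# Dummy stand-ins for results["mse_loss"] and results["physics_loss"];
# the values are made up for illustration.
psf = 0.1
mse_loss = torch.tensor(2.0, requires_grad=True)
physics_loss = torch.tensor(500.0, requires_grad=True)

# Scale computed without gradient tracking, as in train_step.
with torch.no_grad():
    scale = (psf * mse_loss) / (physics_loss + 1e-5)
final_loss = mse_loss + scale * physics_loss

final_loss.backward()
print(final_loss.item())         # ~2.2, i.e. (1 + 0.1) * mse_loss
print(mse_loss.grad.item())      # 1.0, gradient of the mse term alone
print(physics_loss.grad.item())  # ~4e-4, physics gradient scaled by psf*mse/physics
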
diff --git a/src/molearn/trainers/openmm_physics_trainer.py b/src/molearn/trainers/openmm_physics_trainer.py
index 37b1979..1398988 100644
--- a/src/molearn/trainers/openmm_physics_trainer.py
+++ b/src/molearn/trainers/openmm_physics_trainer.py
@@ -4,7 +4,7 @@
import os
-soft_xml_script='''\
+soft_xml_script = """\
-'''
-
+"""
class OpenMM_Physics_Trainer(Trainer):
- '''
+ """
OpenMM_Physics_Trainer subclasses Trainer and replaces the valid_step and train_step.
An extra 'physics_loss' is calculated using OpenMM and the forces are inserted into the backward pass.
Using this trainer requires the additional step of calling :func:`prepare_physics `.
- '''
+ """
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
-
- def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp=False, start_physics_at=0, xml_file = None, soft_NB = True, **kwargs):
- '''
+
+ def prepare_physics(
+ self,
+ physics_scaling_factor=0.1,
+ clamp_threshold=1e8,
+ clamp=False,
+ start_physics_at=0,
+ xml_file=None,
+ soft_NB=True,
+ **kwargs,
+ ):
+ """
Create ``self.physics_loss`` object from :func:`loss_functions.openmm_energy `
Needs ``self.mol``, ``self.std``, and ``self._data.atoms`` to have been set with :func:`Trainer.set_data`
@@ -47,70 +56,81 @@ def prepare_physics(self, physics_scaling_factor=0.1, clamp_threshold=1e8, clamp
:param int start_physics_at: As of yet unused parameter saved as ``self.start_physics_at = start_physics_at``. Default: 0
:param \*\*kwargs: All additional kwargs will be passed to :func:`openmm_energy `
- '''
- if xml_file is None and soft_NB==True:
- print('using soft nonbonded forces by default')
+ """
+ if xml_file is None and soft_NB == True:
+ print("using soft nonbonded forces by default")
from molearn.utils import random_string
- tmp_filename = f'soft_nonbonded_{random_string()}.xml'
- with open(tmp_filename, 'w') as f:
+
+ tmp_filename = f"soft_nonbonded_{random_string()}.xml"
+ with open(tmp_filename, "w") as f:
f.write(soft_xml_script)
- xml_file = ['amber14-all.xml', tmp_filename]
- kwargs['remove_NB'] = True
+ xml_file = ["amber14-all.xml", tmp_filename]
+ kwargs["remove_NB"] = True
elif xml_file is None:
- xml_file = ['amber14-all.xml']
+ xml_file = ["amber14-all.xml"]
self.start_physics_at = start_physics_at
self.psf = physics_scaling_factor
if clamp:
clamp_kwargs = dict(max=clamp_threshold, min=-clamp_threshold)
else:
clamp_kwargs = None
- self.physics_loss = openmm_energy(self.mol, self.std, clamp=clamp_kwargs, platform='CUDA' if self.device == torch.device('cuda') else 'Reference', atoms=self._data.atoms, xml_file = xml_file, **kwargs)
+ self.physics_loss = openmm_energy(
+ self.mol,
+ self.std,
+ clamp=clamp_kwargs,
+ platform="CUDA" if self.device == torch.device("cuda") else "Reference",
+ atoms=self._data.atoms,
+ xml_file=xml_file,
+ **kwargs,
+ )
os.remove(tmp_filename)
def common_physics_step(self, batch, latent):
- '''
+ """
Called from both :func:`train_step ` and :func:`valid_step `.
Takes random interpolations between adjacent samples' latent vectors. These are decoded (the decoded structures are saved as ``self._internal['generated']`` if needed elsewhere) and the energy terms are calculated with ``self.physics_loss``.
:param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms]. Gives access to the mini-batch of structures; only used to determine ``n_atoms``.
:param torch.Tensor latent: tensor shape [batch_size, 2, 1]. Pass the encoded vectors of the mini-batch.
- '''
- alpha = torch.rand(int(len(batch)//2), 1, 1).type_as(latent)
- latent_interpolated = (1-alpha)*latent[:-1:2] + alpha*latent[1::2]
+ """
+ alpha = torch.rand(int(len(batch) // 2), 1, 1).type_as(latent)
+ latent_interpolated = (1 - alpha) * latent[:-1:2] + alpha * latent[1::2]
- generated = self.autoencoder.decode(latent_interpolated)[:, :, :batch.size(2)]
- self._internal['generated'] = generated
+ generated = self.autoencoder.decode(latent_interpolated)[:, :, : batch.size(2)]
+ self._internal["generated"] = generated
energy = self.physics_loss(generated)
energy[energy.isinf()] = 1e35
energy = torch.clamp(energy, max=1e34)
energy = energy.nanmean()
- return {'physics_loss':energy} # a if not energy.isinf() else torch.tensor(0.0)}
+ return {
+ "physics_loss": energy
+ } # a if not energy.isinf() else torch.tensor(0.0)}
def train_step(self, batch):
- '''
+ """
This method overrides :func:`Trainer.train_step ` and adds an additional 'physics_loss' term.
The mse_loss and the physics_loss are summed (``mse_loss + scale*physics_loss``) with a scaling factor ``scale = self.psf*mse_loss/physics_loss``. Mathematically this cancels out the physics_loss and the final loss value is ``(1+self.psf)*mse_loss``. However, because the scaling factor is calculated within a ``torch.no_grad`` context manager, no gradients are computed through it.
- This is essentially the same as scaling the physics_loss with any arbitary scaling factor but in this case simply happens to be exactly proportional to the ration of Mse_loss and physics_loss in every step.
+ This is essentially the same as scaling the physics_loss with an arbitrary scaling factor, except that here the factor happens to be exactly proportional to the ratio of mse_loss to physics_loss in every step.
Called from :func:`Trainer.train_epoch `.
:param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by ``self.std``.
:returns: The loss dictionary. It must contain an entry with key ``'loss'``, on which :func:`self.train_epoch ` will call ``result['loss'].backward()`` to obtain gradients.
:rtype: dict
- '''
-
+ """
+
results = self.common_step(batch)
- results.update(self.common_physics_step(batch, self._internal['encoded']))
+ results.update(self.common_physics_step(batch, self._internal["encoded"]))
with torch.no_grad():
- scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
- final_loss = results['mse_loss']+scale*results['physics_loss']
- results['loss'] = final_loss
+ scale = (self.psf * results["mse_loss"]) / (results["physics_loss"] + 1e-5)
+ final_loss = results["mse_loss"] + scale * results["physics_loss"]
+ results["loss"] = final_loss
return results
def valid_step(self, batch):
- '''
+ """
This method overrides :func:`Trainer.valid_step ` and adds an additional 'physics_loss' term.
Unlike :func:`train_step `, this method sums the logs of mse_loss and physics_loss: ``final_loss = torch.log(results['mse_loss']) + self.psf*torch.log(results['physics_loss'])``
@@ -121,15 +141,17 @@ def valid_step(self, batch):
:returns: The loss dictionary. It must contain an entry with key ``'loss'``, which is the score used to determine the best checkpoint.
:rtype: dict
- '''
+ """
results = self.common_step(batch)
- results.update(self.common_physics_step(batch, self._internal['encoded']))
+ results.update(self.common_physics_step(batch, self._internal["encoded"]))
# scale = (self.psf*results['mse_loss'])/(results['physics_loss'] +1e-5)
- final_loss = torch.log(results['mse_loss'])+self.psf*torch.log(results['physics_loss'])
- results['loss'] = final_loss
+ final_loss = torch.log(results["mse_loss"]) + self.psf * torch.log(
+ results["physics_loss"]
+ )
+ results["loss"] = final_loss
return results
-if __name__=='__main__':
+if __name__ == "__main__":
pass
From 8eaa644c1732d7ce905d0760b787c54167b015a2 Mon Sep 17 00:00:00 2001
From: gwirn <71886945+gwirn@users.noreply.github.com>
Date: Fri, 15 Nov 2024 10:59:01 +0100
Subject: [PATCH 2/2] new scaling for physics loss
---
.../trainers/openmm_physics_trainer.py | 24 ++++++++++++-------
src/molearn/trainers/trainer.py | 16 +++++++++++++
2 files changed, 31 insertions(+), 9 deletions(-)
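
For reference, a minimal standalone sketch (not part of the diff below) of the order-of-magnitude scaling that the new _get_scale helper implements: the physics loss is rescaled so that it lands in the same decade as the MSE loss, with scale_scale acting as an extra multiplier. The loss values here are made up for illustration:

import math

def get_scale(cur_mse_loss: float, loss_to_scale: float, scale_scale: float = 1.0) -> float:
    # Compare the decades (floor of log10) of the two losses; guard against
    # non-positive inputs the same way the patched trainer does.
    mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
    mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
    return 10 ** (mag_mse - mag_phy) * scale_scale

mse_loss, physics_loss = 3.2e-3, 4.7e4     # hypothetical epoch losses
scale = get_scale(mse_loss, physics_loss)  # 10 ** (-3 - 4) = 1e-07
print(scale * physics_loss)                # 4.7e-3, same decade as the MSE loss

In train_step the factor is computed once, at epoch start_physics_at, with self.psf passed as scale_scale, so the default psf of 0.1 places the scaled physics term one decade below the MSE loss.
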
diff --git a/src/molearn/trainers/openmm_physics_trainer.py b/src/molearn/trainers/openmm_physics_trainer.py
index 1398988..e658265 100644
--- a/src/molearn/trainers/openmm_physics_trainer.py
+++ b/src/molearn/trainers/openmm_physics_trainer.py
@@ -41,7 +41,7 @@ def prepare_physics(
physics_scaling_factor=0.1,
clamp_threshold=1e8,
clamp=False,
- start_physics_at=0,
+ start_physics_at=10,
xml_file=None,
soft_NB=True,
**kwargs,
@@ -53,11 +53,11 @@ def prepare_physics(
:param float physics_scaling_factor: scaling factor saved to ``self.psf`` that is used in :func:`train_step `. Defaults to 0.1
:param float clamp_threshold: if ``clamp=True`` is passed, forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e8
:param bool clamp: Whether to clamp the forces. Defaults to False
- :param int start_physics_at: As of yet unused parameter saved as ``self.start_physics_at = start_physics_at``. Default: 0
+ :param int start_physics_at: The epoch from which the physics loss is added to the total loss. Default: 10
:param \*\*kwargs: All additional kwargs will be passed to :func:`openmm_energy `
"""
- if xml_file is None and soft_NB == True:
+ if xml_file is None and soft_NB:
print("using soft nonbonded forces by default")
from molearn.utils import random_string
@@ -110,10 +110,6 @@ def common_physics_step(self, batch, latent):
def train_step(self, batch):
"""
This method overrides :func:`Trainer.train_step ` and adds an additional 'physics_loss' term.
-
- The mse_loss and the physics_loss are summed (``mse_loss + scale*physics_loss``) with a scaling factor ``scale = self.psf*mse_loss/physics_loss``. Mathematically this cancels out the physics_loss and the final loss value is ``(1+self.psf)*mse_loss``. However, because the scaling factor is calculated within a ``torch.no_grad`` context manager, no gradients are computed through it.
- This is essentially the same as scaling the physics_loss with an arbitrary scaling factor, except that here the factor happens to be exactly proportional to the ratio of mse_loss to physics_loss in every step.
-
Called from :func:`Trainer.train_epoch `.
:param torch.Tensor batch: tensor of shape [batch_size, 3, n_atoms]. A mini-batch of normalised protein frames. To recover the original data, multiply by ``self.std``.
@@ -123,9 +119,19 @@ def train_step(self, batch):
results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal["encoded"]))
+
with torch.no_grad():
- scale = (self.psf * results["mse_loss"]) / (results["physics_loss"] + 1e-5)
- final_loss = results["mse_loss"] + scale * results["physics_loss"]
+ if self.epoch == self.start_physics_at:
+ self.phy_scale = self._get_scale(
+ results["mse_loss"],
+ results["physics_loss"],
+ self.psf,
+ )
+ if self.epoch >= self.start_physics_at:
+ final_loss = results["mse_loss"] + self.phy_scale * results["physics_loss"]
+ else:
+ final_loss = results["mse_loss"]
+
results["loss"] = final_loss
return results
diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py
index e66c849..1df3ff9 100644
--- a/src/molearn/trainers/trainer.py
+++ b/src/molearn/trainers/trainer.py
@@ -1,6 +1,7 @@
import os
import glob
import shutil
+import math
import numpy as np
import time
import torch
@@ -541,6 +542,21 @@ def get_repeat(self, checkpoint_folder):
"Something went wrong, you surely havnt done 1000 repeats?"
)
+ def _get_scale(
+ self, cur_mse_loss: float, loss_to_scale: float, scale_scale: float = 1.0
+ ):
+ """
+ Get a scaling factor that brings a loss into the same order of magnitude as `cur_mse_loss`.
+
+ :param float cur_mse_loss: the mse loss of the current epoch
+ :param float loss_to_scale: the loss that should be scaled to the same order of magnitude as `cur_mse_loss`
+ :param float scale_scale: additional factor to further increase or decrease the calculated scale
+ :returns: the calculated scaling factor for `loss_to_scale`
+ :rtype: float
+ """
+ mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
+ mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
+ return 10 ** (mag_mse - mag_phy) * scale_scale
+
if __name__ == "__main__":
pass