Skip to content

Commit

Permalink
new scaling for physics loss
Browse files Browse the repository at this point in the history
  • Loading branch information
gwirn committed Nov 15, 2024
1 parent c28e8cb commit 8eaa644
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/molearn/trainers/openmm_physics_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def prepare_physics(
physics_scaling_factor=0.1,
clamp_threshold=1e8,
clamp=False,
start_physics_at=0,
start_physics_at=10,
xml_file=None,
soft_NB=True,
**kwargs,
Expand All @@ -53,11 +53,11 @@ def prepare_physics(
:param float physics_scaling_factor: scaling factor saved to ``self.psf`` that is used in :func:`train_step <molearn.trainers.OpenMM_Physics_Trainer.train_step>`. Defaults to 0.1
:param float clamp_threshold: if ``clamp=True`` is passed then forces will be clamped between -clamp_threshold and clamp_threshold. Default: 1e-8
:param bool clamp: Whether to clamp the forces. Defaults to False
:param int start_physics_at: As of yet unused parameter saved as ``self.start_physics_at = start_physics_at``. Default: 0
:param int start_physics_at: At which epoch the physics loss will be added to the loss. Default: 10
:param \*\*kwargs: All aditional kwargs will be passed to :func:`openmm_energy <molearn.loss_functions.openmm_energy>`
"""
if xml_file is None and soft_NB == True:
if xml_file is None and soft_NB:
print("using soft nonbonded forces by default")
from molearn.utils import random_string

Expand Down Expand Up @@ -110,10 +110,6 @@ def common_physics_step(self, batch, latent):
def train_step(self, batch):
"""
This method overrides :func:`Trainer.train_step <molearn.trainers.Trainer.train_step>` and adds an additional 'Physics_loss' term.
Mse_loss and physics loss are summed (``Mse_loss + scale*physics_loss``)with a scaling factor ``self.psf*mse_loss/Physics_loss``. Mathematically this cancels out the physics_loss and the final loss is (1+self.psf)*mse_loss. However because the scaling factor is calculated within a ``torch.no_grad`` context manager the gradients are not computed.
This is essentially the same as scaling the physics_loss with any arbitary scaling factor but in this case simply happens to be exactly proportional to the ration of Mse_loss and physics_loss in every step.
Called from :func:`Trainer.train_epoch <molearn.trainers.Trainer.train_epoch>`.
:param torch.Tensor batch: tensor shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``.
Expand All @@ -123,9 +119,19 @@ def train_step(self, batch):

results = self.common_step(batch)
results.update(self.common_physics_step(batch, self._internal["encoded"]))

with torch.no_grad():
scale = (self.psf * results["mse_loss"]) / (results["physics_loss"] + 1e-5)
final_loss = results["mse_loss"] + scale * results["physics_loss"]
if self.epoch == self.start_physics_at:
self.phy_scale = self._get_scale(
results["mse_loss"],
results["physics_loss"],
self.psf,
)
if self.epoch >= self.start_physics_at:
final_loss = results["mse_loss"] + self.phy_scale * results["physics_loss"]
else:
final_loss = results["mse_loss"]

results["loss"] = final_loss
return results

Expand Down
16 changes: 16 additions & 0 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import glob
import shutil
import math
import numpy as np
import time
import torch
Expand Down Expand Up @@ -541,6 +542,21 @@ def get_repeat(self, checkpoint_folder):
"Something went wrong, you surely havnt done 1000 repeats?"
)

def _get_scale(
self, cur_mse_loss: float, loss_to_scale: float, scale_scale: float = 1.0
):
"""
get a scaling factor to scale a loss to be in the same order of magnitude like `cur_mse_loss`
:param float cur_mse_loss: the mse loss of the current epoch
:param float loss_to_scale: the loss that should be scaled to be in the same order of magnitude as the `cur_mse_loss`
:param float scale_scale: scale to in-/ decrease the scale further
:return float scaling_factor: the calculated scaling factor for the `loss_to_scale`
"""
mag_mse = math.floor(math.log10(cur_mse_loss if cur_mse_loss > 0 else 1e-32))
mag_phy = math.floor(math.log10(loss_to_scale if loss_to_scale > 0 else 1e-32))
return 10 ** (mag_mse - mag_phy) * scale_scale


if __name__ == "__main__":
pass

0 comments on commit 8eaa644

Please sign in to comment.