Skip to content

Commit

Permalink
Add timeout to OpenMM minimiser. [closes #230]
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Oct 2, 2024
1 parent 4fcce0f commit af24297
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 4 deletions.
14 changes: 14 additions & 0 deletions src/sire/mol/_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def run_minimisation(
starting_k: float = 100.0,
ratchet_scale: float = 2.0,
max_constraint_error: float = 0.001,
timeout: str = "60s",
):
"""
Internal method that runs minimisation on the molecules.
Expand Down Expand Up @@ -619,12 +620,24 @@ def run_minimisation(
- starting_k (float): The starting value of k for the minimisation
- ratchet_scale (float): The amount to scale k at each ratchet
- max_constraint_error (float): The maximum error in the constraint in nm
- timeout (float): The maximum time to run the minimisation for in seconds.
A value of <=0 will disable the timeout.
"""
from ..legacy.Convert import minimise_openmm_context

if max_iterations <= 0:
max_iterations = 0

try:
from ..units import second
from .. import u

timeout = u(timeout)
if not timeout.has_same_units(second):
raise ValueError("'timeout' must have units of time")
except:
raise ValueError("Unable to parse 'timeout' as a time")

self._minimisation_log = minimise_openmm_context(
self._omm_mols,
tolerance=tolerance,
Expand All @@ -635,6 +648,7 @@ def run_minimisation(
starting_k=starting_k,
ratchet_scale=ratchet_scale,
max_constraint_error=max_constraint_error,
timeout=timeout.to(second),
)

def _rebuild_and_minimise(self):
Expand Down
4 changes: 4 additions & 0 deletions src/sire/mol/_minimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def run(
starting_k: float = 400.0,
ratchet_scale: float = 10.0,
max_constraint_error: float = 0.001,
timeout: str = "60s",
):
"""
Internal method that runs minimisation on the molecules.
Expand Down Expand Up @@ -129,6 +130,8 @@ def run(
- starting_k (float): The starting value of k for the minimisation
- ratchet_scale (float): The amount to scale k at each ratchet
- max_constraint_error (float): The maximum error in the constraint in nm
- timeout (float): The maximum time to run the minimisation for in seconds.
A value of <=0 will disable the timeout.
"""
if not self._d.is_null():
self._d.run_minimisation(
Expand All @@ -140,6 +143,7 @@ def run(
starting_k=starting_k,
ratchet_scale=ratchet_scale,
max_constraint_error=max_constraint_error,
timeout=timeout,
)

return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2143,13 +2143,13 @@ void register_free_functions(){

{ //::SireOpenMM::minimise_openmm_context

typedef ::QString ( *minimise_openmm_context_function_type )( ::OpenMM::Context &,double,int,int,int,int,double,double,double );
typedef ::QString ( *minimise_openmm_context_function_type )( ::OpenMM::Context &,double,int,int,int,int,double,double,double,double );
minimise_openmm_context_function_type minimise_openmm_context_function_value( &::SireOpenMM::minimise_openmm_context );

bp::def(
"minimise_openmm_context"
, minimise_openmm_context_function_value
, ( bp::arg("context"), bp::arg("tolerance")=10., bp::arg("max_iterations")=(int)(-1), bp::arg("max_restarts")=(int)(10), bp::arg("max_ratchets")=(int)(20), bp::arg("ratchet_frequency")=(int)(500), bp::arg("starting_k")=100., bp::arg("ratchet_scale")=2., bp::arg("max_constraint_error")=0.01 )
, ( bp::arg("context"), bp::arg("tolerance")=10., bp::arg("max_iterations")=(int)(-1), bp::arg("max_restarts")=(int)(10), bp::arg("max_ratchets")=(int)(20), bp::arg("ratchet_frequency")=(int)(500), bp::arg("starting_k")=100., bp::arg("ratchet_scale")=2., bp::arg("max_constraint_error")=0.01, bp::arg("timeout")=60. )
, "This is a minimiser heavily inspired by the\nLocalEnergyMinimizer included in OpenMM. This is re-written\nfor sire to;\n\n1. Better integrate minimisation into the sire progress\nmonitoring interupting framework.\n2. Avoid errors caused by OpenMM switching from the desired\ncontext to the CPU context, thus triggering spurious exceptions\nrelated to exclusions exceptions not matching\n\nThis exposes more controls from the underlying minimisation\nlibrary, and also logs events and progress, which is returned\nas a string.\n\nThis raises an exception if minimisation fails.\n" );

}
Expand Down
34 changes: 33 additions & 1 deletion wrapper/Convert/SireOpenMM/openmmminimise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

// COPIED FROM SO POST - https://stackoverflow.com/questions/570669/checking-if-a-double-or-float-is-nan-in-c

#include <chrono>
#include <cmath> // std::isnan, std::fpclassify
#include <iostream>
#include <iomanip> // std::setw
Expand Down Expand Up @@ -616,13 +617,18 @@ namespace SireOpenMM
int max_restarts, int max_ratchets,
int ratchet_frequency,
double starting_k, double ratchet_scale,
double max_constraint_error)
double max_constraint_error, double timeout)
{
if (max_iterations < 0)
{
max_iterations = std::numeric_limits<int>::max();
}

if (timeout <= 0)
{
timeout = std::numeric_limits<double>::max();
}

auto gil = SireBase::release_gil();

const OpenMM::System &system = context.getSystem();
Expand Down Expand Up @@ -650,6 +656,7 @@ namespace SireOpenMM

data.addLog(QString("Minimising with a tolerance of %1").arg(tolerance));
data.addLog(QString("Minimising with constraint tolerance %1").arg(working_constraint_tol));
data.addLog(QString("Minimising with a timeout of %1 seconds").arg(timeout));
data.addLog(QString("Minimising with k = %1").arg(k));
data.addLog(QString("Minimising with %1 particles").arg(num_particles));
data.addLog(QString("Minimising with a maximum of %1 iterations").arg(max_iterations));
Expand Down Expand Up @@ -679,13 +686,26 @@ namespace SireOpenMM
int max_linesearch = 100;
const int max_linesearch_delta = 100;

// Store the starting time.
auto start_time = std::chrono::high_resolution_clock::now();

while (data.getIteration() < data.getMaxIterations())
{
if (not is_success)
{
// try one more time with the real starting positions
if (not have_hard_reset)
{
// Check the current time and see if we've exceeded the timeout.
auto current_time = std::chrono::high_resolution_clock::now();
auto elapsed_time = std::chrono::duration_cast<std::chrono::seconds>(current_time - start_time).count();

if (elapsed_time > timeout)
{
data.addLog("Minimisation timed out!");
break;
}

data.hardReset();

context.setPositions(starting_pos);
Expand All @@ -709,6 +729,7 @@ namespace SireOpenMM
}
}


data.addLog(QString("Minimisation loop - %1 steps from %2").arg(data.getIteration()).arg(data.getMaxIterations()));

try
Expand Down Expand Up @@ -762,6 +783,17 @@ namespace SireOpenMM
// Repeatedly minimize, steadily increasing the strength of the springs until all constraints are satisfied.
while (data.getIteration() < data.getMaxIterations())
{
// Check the current time and see if we've exceeded the timeout.
auto current_time = std::chrono::high_resolution_clock::now();
auto elapsed_time = std::chrono::duration_cast<std::chrono::seconds>(current_time - start_time).count();

if (elapsed_time > timeout)
{
data.addLog("Minimisation timed out!");
is_success = false;
break;
}

param.max_iterations = data.getMaxIterations() - data.getIteration();
lbfgsfloatval_t fx; // final energy
auto last_it = data.getIteration();
Expand Down
3 changes: 2 additions & 1 deletion wrapper/Convert/SireOpenMM/openmmminimise.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace SireOpenMM
int ratchet_frequency = 500,
double starting_k = 100.0,
double ratchet_scale = 2.0,
double max_constraint_error = 0.01);
double max_constraint_error = 0.01,
double timeout = 60.0);

}

Expand Down
2 changes: 2 additions & 0 deletions wrapper/Convert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def minimise_openmm_context(
starting_k: float = 100.0,
ratchet_scale: float = 2.0,
max_constraint_error: float = 0.01,
timeout: str = "60s",
):
return _minimise_openmm_context(
context,
Expand All @@ -503,6 +504,7 @@ def minimise_openmm_context(
starting_k=starting_k,
ratchet_scale=ratchet_scale,
max_constraint_error=max_constraint_error,
timeout=timeout,
)

except Exception as e:
Expand Down

0 comments on commit af24297

Please sign in to comment.