diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index b018200..c61b71c 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -103,7 +103,9 @@ def __init__(self, system, lambdas, num_gpus, dynamics_kwargs): # Append the dynamics object. self._dynamics.append(dynamics) - _logger.info(f"Created dynamics object for lambda {lam:.5f} on device {device}") + _logger.info( + f"Created dynamics object for lambda {lam:.5f} on device {device}" + ) def get(self, index): """ @@ -261,6 +263,11 @@ def __init__(self, system, config): else: self._start_block = 0 + from threading import Lock + + # Create a lock to guard the dynamics cache. + self._lock = Lock() + def __str__(self): """Return a string representation of the object.""" return f"RepexRunner(system={self._system}, config={self._config})" @@ -329,31 +336,11 @@ def run(self): # Create the replica list. replica_list = list(range(self._config.num_lambda)) - # Minimise at each lambda value. + # Minimise at each lambda value. This is currently done in serial due to a + # limitation in OpenMM. if self._config.minimise: - # Run minimisation for each replica, making sure only each GPU is only - # oversubscribed by a factor of self._config.oversubscription_factor. - for i in range(num_batches): - with ThreadPoolExecutor() as executor: - try: - for result, index, exception in executor.map( - self._minimise, - replica_list[ - i - * self._num_gpus - * self._config.oversubscription_factor : (i + 1) - * self._num_gpus - * self._config.oversubscription_factor - ], - ): - if not result: - _logger.error( - f"Minimisation failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {exception}" - ) - raise exception - except KeyboardInterrupt: - _logger.error("Minimisation cancelled. Exiting.") - exit(1) + for i in range(self._config.num_lambda): + self._minimise(i) # Current block number. block = 0 @@ -515,9 +502,10 @@ def _run_block( speed = dynamics.time_speed() # Checkpoint. - self._checkpoint( - system, index, block, speed, is_final_block=is_final_block - ) + with self._lock: + self._checkpoint( + system, index, block, speed, is_final_block=is_final_block + ) _logger.info( f"Finished block {block+1} of {self._start_block + num_blocks} "