diff --git a/src/somd2/runner/_repex.py b/src/somd2/runner/_repex.py index 94af669..e60bd2c 100644 --- a/src/somd2/runner/_repex.py +++ b/src/somd2/runner/_repex.py @@ -162,7 +162,6 @@ def set_states(self, states): states: np.ndarray The new states. """ - self._old_states = self._states self._states = states def mix_states(self): @@ -170,13 +169,11 @@ def mix_states(self): Mix the states of the dynamics objects. """ # Mix the states. - for i, (old_state, new_state) in enumerate(zip(self._old_states, self._states)): + for i, state in enumerate(self._states): # The state has changed. - if old_state != new_state: - _logger.debug( - f"Replica {i} changed state from {old_state} to {new_state}" - ) - self._dynamics[i]._d._omm_mols.setState(self._openmm_states[new_state]) + if i != state: + _logger.debug(f"Replica {i} changed state to {state}") + self._dynamics[i]._d._omm_mols.setState(self._openmm_states[state]) class RepexRunner(_RunnerBase): @@ -394,7 +391,6 @@ def run(self): self._mix_replicas( self._config.num_lambda, energy_matrix, - self._dynamics_cache._states, ) ) self._dynamics_cache.mix_states() @@ -567,7 +563,7 @@ def _assemble_results(self, results): @staticmethod @_njit - def _mix_replicas(num_replicas, energy_matrix, states): + def _mix_replicas(num_replicas, energy_matrix): """ Mix the replicas. @@ -577,15 +573,9 @@ def _mix_replicas(num_replicas, energy_matrix, states): num_replicas: int The number of replicas. - num_attempts: int - The number of attempts to make. - energy_matrix: np.ndarray The energy matrix for the replicas. - states: np.ndarray - The current state for each replica. - Returns ------- @@ -593,9 +583,12 @@ def _mix_replicas(num_replicas, energy_matrix, states): The new states. """ - # Copy the states. - states = states.copy() + # Adapted from OpenMMTools: https://github.com/choderalab/openmmtools + + # Set the states to the initial order. + states = _np.arange(num_replicas) + # Attempt swaps. for swap in range(num_replicas**3): # Choose two replicas to swap. replica_i = _np.random.randint(num_replicas) @@ -614,7 +607,7 @@ def _mix_replicas(num_replicas, energy_matrix, states): # Compute the log probability of the swap. log_p_swap = -(energy_ij + energy_ji) + energy_ii + energy_jj - # Accept or reject the swap. + # Accept the swap and update the states. if log_p_swap >= 0 or _np.random.rand() < _np.exp(log_p_swap): states[replica_i] = state_j states[replica_j] = state_i