diff --git a/devtools/data/gen-serialized-results.py b/devtools/data/gen-serialized-results.py index bd74a01f5..e402b8af5 100644 --- a/devtools/data/gen-serialized-results.py +++ b/devtools/data/gen-serialized-results.py @@ -91,8 +91,13 @@ def generate_md_json(smc): def generate_ahfe_json(smc): settings = AbsoluteSolvationProtocol.default_settings() + settings.solvent_equil_simulation_settings.equilibration_length_nvt = 10 * unit.picosecond + settings.solvent_equil_simulation_settings.equilibration_length = 10 * unit.picosecond + settings.solvent_equil_simulation_settings.production_length = 10 * unit.picosecond settings.solvent_simulation_settings.equilibration_length = 10 * unit.picosecond settings.solvent_simulation_settings.production_length = 500 * unit.picosecond + settings.vacuum_equil_simulation_settings.equilibration_length = 10 * unit.picosecond + settings.vacuum_equil_simulation_settings.production_length = 10 * unit.picosecond settings.vacuum_simulation_settings.equilibration_length = 10 * unit.picosecond settings.vacuum_simulation_settings.production_length = 1000 * unit.picosecond settings.lambda_settings.lambda_elec = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 1761a4dee..b698b3a35 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -60,6 +60,7 @@ ThermoSettings, OpenFFPartialChargeSettings, ) from openfe.protocols.openmm_rfe._rfe_utils import compute +from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit from ..openmm_utils import ( settings_validation, system_creation, multistate_analysis, charge_generation @@ -152,37 +153,113 @@ def _get_alchemical_indices(omm_top: openmm.Topology, return atom_ids - @staticmethod - def _pre_minimize(system: openmm.System, - positions: omm_unit.Quantity) -> npt.NDArray: + def _pre_equilibrate( + self, + system: openmm.System, + topology: openmm.app.Topology, + positions: 
omm_unit.Quantity, + settings: dict[str, SettingsBaseModel], + dry: bool + ) -> omm_unit.Quantity: """ - Short CPU minization of System to avoid GPU NaNs + Run a non-alchemical equilibration to get a stable system. Parameters ---------- system : openmm.System - An OpenMM System to minimize. - positionns : openmm.unit.Quantity + An OpenMM System to equilibrate. + topology : openmm.app.Topology + OpenMM Topology of the System. + positions : openmm.unit.Quantity Initial positions for the system. + settings : dict[str, SettingsBaseModel] + A dictionary of settings objects. Expects the + following entries: + * `engine_settings` + * `thermo_settings` + * `integrator_settings` + * `equil_simulation_settings` + * `equil_output_settings` + dry : bool + Whether or not this is a dry run. Returns ------- - minimized_positions : npt.NDArray - Minimized positions + equilibrated_positions : openmm.unit.Quantity + Equilibrated system positions (the input positions if `dry`) """ - integrator = openmm.VerletIntegrator(0.001) - context = openmm.Context( - system, integrator, - openmm.Platform.getPlatformByName('CPU'), + # Prep the simulation object + platform = compute.get_openmm_platform( + settings['engine_settings'].compute_platform ) - context.setPositions(positions) - # Do a quick 100 steps minimization, usually avoids NaNs - openmm.LocalEnergyMinimizer.minimize( - context, maxIterations=100 + + integrator = openmm.LangevinMiddleIntegrator( + to_openmm(settings['thermo_settings'].temperature), + to_openmm(settings['integrator_settings'].langevin_collision_rate), + to_openmm(settings['integrator_settings'].timestep), ) - state = context.getState(getPositions=True) - minimized_positions = state.getPositions(asNumpy=True) - return minimized_positions + + simulation = openmm.app.Simulation( + topology=topology, + system=system, + integrator=integrator, + platform=platform, + ) + + # Get the necessary number of steps + if settings['equil_simulation_settings'].equilibration_length_nvt is not None: + equil_steps_nvt = 
settings_validation.get_simsteps( + sim_length=settings[ + 'equil_simulation_settings'].equilibration_length_nvt, + timestep=settings['integrator_settings'].timestep, + mc_steps=1, + ) + else: + equil_steps_nvt = None + + equil_steps_npt = settings_validation.get_simsteps( + sim_length=settings['equil_simulation_settings'].equilibration_length, + timestep=settings['integrator_settings'].timestep, + mc_steps=1, + ) + + prod_steps_npt = settings_validation.get_simsteps( + sim_length=settings['equil_simulation_settings'].production_length, + timestep=settings['integrator_settings'].timestep, + mc_steps=1, + ) + + if self.verbose: + logger.info("running non-alchemical equilibration MD") + + # Don't do anything if we're doing a dry run + if dry: + return positions + + # Use the _run_MD method from the PlainMDProtocolUnit + # Should in-place modify the simulation + PlainMDProtocolUnit._run_MD( + simulation=simulation, + positions=positions, + simulation_settings=settings['equil_simulation_settings'], + output_settings=settings['equil_output_settings'], + temperature=settings['thermo_settings'].temperature, + barostat_frequency=settings['integrator_settings'].barostat_frequency, + timestep=settings['integrator_settings'].timestep, + equil_steps_nvt=equil_steps_nvt, + equil_steps_npt=equil_steps_npt, + prod_steps=prod_steps_npt, + verbose=self.verbose, + shared_basepath=self.shared_basepath, + ) + + state = simulation.context.getState(getPositions=True) + equilibrated_positions = state.getPositions(asNumpy=True) + + # cautiously delete out contexts & integrator + del simulation.context, integrator + + return equilibrated_positions def _prepare( self, verbose: bool, @@ -241,6 +318,8 @@ def _handle_settings(self): * integrator_settings : IntegratorSettings * simulation_settings : MultiStateSimulationSettings * output_settings: OutputSettings + * equil_simulation_settings: MDSimulationSettings + * equil_output_settings: MDOutputSettings Settings may change depending on what 
type of simulation you are running. Cherry pick them and return them to be available later on. @@ -914,8 +993,10 @@ def run(self, dry=False, verbose=True, system_modeller, system_generator, list(smc_comps.values()) ) - # 6. Pre-minimize System (Test + Avoid NaNs) - positions = self._pre_minimize(omm_system, positions) + # 6. Pre-equilibrate System (Test + Avoid NaNs + get stable system) + positions = self._pre_equilibrate( + omm_system, omm_topology, positions, settings, dry + ) # 7. Get lambdas lambdas = self._get_lambda_schedule(settings) diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py index 9f1df2435..cfda639d5 100644 --- a/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -28,6 +28,8 @@ IntegratorSettings, OpenFFPartialChargeSettings, OutputSettings, + MDSimulationSettings, + MDOutputSettings, ) import numpy as np @@ -173,20 +175,41 @@ def must_be_positive(cls, v): """ # Simulation run settings + vacuum_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical vacuum simulation control settings. + + Notes + ----- + The `NVT` equilibration length should be `None` or 0 * unit.nanosecond + as NVT equilibration cannot be run in vacuum. + """ vacuum_simulation_settings: MultiStateSimulationSettings """ Simulation control settings, including simulation lengths for the vacuum transformation. """ + solvent_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical solvent simulation control settings. + """ solvent_simulation_settings: MultiStateSimulationSettings """ Simulation control settings, including simulation lengths for the solvent transformation. """ + vacuum_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the vacuum non-alchemical equilibration. + """ vacuum_output_settings: OutputSettings """ Simulation output settings for the vacuum transformation. 
""" + solvent_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the solvent non-alchemical equilibration. + """ solvent_output_settings: OutputSettings """ Simulation output settings for the solvent transformation. diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index e13fb68e4..f00bc8b46 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -51,6 +51,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import ( AbsoluteSolvationSettings, OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, + MDSimulationSettings, MDOutputSettings, MultiStateSimulationSettings, OpenMMEngineSettings, IntegratorSettings, OutputSettings, OpenFFPartialChargeSettings, @@ -421,6 +422,17 @@ def _default_settings(cls): vacuum_engine_settings=OpenMMEngineSettings(), solvent_engine_settings=OpenMMEngineSettings(), integrator_settings=IntegratorSettings(), + solvent_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.1 * unit.nanosecond, + equilibration_length=0.2 * unit.nanosecond, + production_length=0.5 * unit.nanosecond, + ), + solvent_equil_output_settings=MDOutputSettings( + equil_nvt_structure='equil_nvt_structure.pdb', + equil_npt_structure='equil_npt_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), solvent_simulation_settings=MultiStateSimulationSettings( n_replicas=14, equilibration_length=1.0 * unit.nanosecond, @@ -430,6 +442,17 @@ def _default_settings(cls): output_filename='solvent.nc', checkpoint_storage_filename='solvent_checkpoint.nc', ), + vacuum_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=None, + equilibration_length=0.2 * unit.nanosecond, + production_length=0.5 * unit.nanosecond, + ), + vacuum_equil_output_settings=MDOutputSettings( + 
equil_nvt_structure=None, + equil_npt_structure='equil_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), vacuum_simulation_settings=MultiStateSimulationSettings( n_replicas=14, equilibration_length=0.5 * unit.nanosecond, @@ -636,6 +659,13 @@ def _create( "passed") raise ValueError(errmsg) + # Check that the vacuum NVT equilibration length is unset (None) or 0 ns + nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt + if nvt_time is not None: + if not np.allclose(nvt_time, 0 * unit.nanosecond): + errmsg = "NVT equilibration cannot be run in vacuum simulation" + raise ValueError(errmsg) + # Get the name of the alchemical species alchname = alchem_comps['stateA'][0].name @@ -749,6 +779,8 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * lambda_settings : LambdaSettings * engine_settings : OpenMMEngineSettings * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings * simulation_settings : SimulationSettings * output_settings: OutputSettings """ @@ -763,6 +795,8 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: settings['lambda_settings'] = prot_settings.lambda_settings settings['engine_settings'] = prot_settings.vacuum_engine_settings settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = prot_settings.vacuum_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.vacuum_equil_output_settings settings['simulation_settings'] = prot_settings.vacuum_simulation_settings settings['output_settings'] = prot_settings.vacuum_output_settings @@ -834,6 +868,8 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * lambda_settings : LambdaSettings * engine_settings : OpenMMEngineSettings * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : 
MDOutputSettings * simulation_settings : MultiStateSimulationSettings * output_settings: OutputSettings """ @@ -848,6 +884,8 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: settings['lambda_settings'] = prot_settings.lambda_settings settings['engine_settings'] = prot_settings.solvent_engine_settings settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = prot_settings.solvent_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.solvent_equil_output_settings settings['simulation_settings'] = prot_settings.solvent_simulation_settings settings['output_settings'] = prot_settings.solvent_output_settings diff --git a/openfe/protocols/openmm_md/plain_md_methods.py b/openfe/protocols/openmm_md/plain_md_methods.py index 92df9dcb9..daff51553 100644 --- a/openfe/protocols/openmm_md/plain_md_methods.py +++ b/openfe/protocols/openmm_md/plain_md_methods.py @@ -257,7 +257,7 @@ def _run_MD(simulation: openmm.app.Simulation, temperature: unit.Quantity, barostat_frequency: unit.Quantity, timestep: unit.Quantity, - equil_steps_nvt: int, + equil_steps_nvt: Optional[int], equil_steps_npt: int, prod_steps: int, verbose=True, @@ -283,8 +283,9 @@ def _run_MD(simulation: openmm.app.Simulation, Frequency for the barostat timestep: FloatQuantity["femtosecond"] Simulation integration timestep - equil_steps_nvt: int + equil_steps_nvt: Optional[int] number of steps for NVT equilibration + if None, no NVT equilibration will be performed equil_steps_npt: int number of steps for NPT equilibration prod_steps: int @@ -327,38 +328,39 @@ def _run_MD(simulation: openmm.app.Simulation, ) # equilibrate # NVT equilibration - - if verbose: - logger.info("Running NVT equilibration") - - # Set barostat frequency to zero for NVT - for x in simulation.context.getSystem().getForces(): - if x.getName() == 'MonteCarloBarostat': - x.setFrequency(0) - - simulation.context.setVelocitiesToTemperature( - 
to_openmm(temperature)) - - t0 = time.time() - simulation.step(equil_steps_nvt) - t1 = time.time() - if verbose: - logger.info( - f"Completed NVT equilibration in {t1 - t0} seconds") - - # Save last frame NVT equilibration - positions = to_openmm( - from_openmm(simulation.context.getState( - getPositions=True, enforcePeriodicBox=False - ).getPositions())) - - traj = mdtraj.Trajectory( - positions[selection_indices, :], - mdtraj_top.subset(selection_indices), - ) - traj.save_pdb( - shared_basepath / output_settings.equil_NVT_structure - ) + if equil_steps_nvt: + if verbose: + logger.info("Running NVT equilibration") + + # Set barostat frequency to zero for NVT + for x in simulation.context.getSystem().getForces(): + if x.getName() == 'MonteCarloBarostat': + x.setFrequency(0) + + simulation.context.setVelocitiesToTemperature( + to_openmm(temperature)) + + t0 = time.time() + simulation.step(equil_steps_nvt) + t1 = time.time() + if verbose: + logger.info( + f"Completed NVT equilibration in {t1 - t0} seconds") + + # Save last frame NVT equilibration + positions = to_openmm( + from_openmm(simulation.context.getState( + getPositions=True, enforcePeriodicBox=False + ).getPositions())) + + traj = mdtraj.Trajectory( + positions[selection_indices, :], + mdtraj_top.subset(selection_indices), + ) + if output_settings.equil_nvt_structure is not None: + traj.save_pdb( + shared_basepath / output_settings.equil_nvt_structure + ) # NPT equilibration if verbose: @@ -388,9 +390,10 @@ def _run_MD(simulation: openmm.app.Simulation, positions[selection_indices, :], mdtraj_top.subset(selection_indices), ) - traj.save_pdb( - shared_basepath / output_settings.equil_NPT_structure - ) + if output_settings.equil_npt_structure is not None: + traj.save_pdb( + shared_basepath / output_settings.equil_npt_structure + ) # production if verbose: @@ -520,10 +523,14 @@ def run(self, *, dry=False, verbose=True, forcefield_settings.hydrogen_mass, timestep ) - equil_steps_nvt = 
settings_validation.get_simsteps( - sim_length=sim_settings.equilibration_length_nvt, - timestep=timestep, mc_steps=1, - ) + if sim_settings.equilibration_length_nvt is not None: + equil_steps_nvt = settings_validation.get_simsteps( + sim_length=sim_settings.equilibration_length_nvt, + timestep=timestep, mc_steps=1, + ) + else: + equil_steps_nvt = None + equil_steps_npt = settings_validation.get_simsteps( sim_length=sim_settings.equilibration_length, timestep=timestep, mc_steps=1, @@ -638,14 +645,18 @@ def run(self, *, dry=False, verbose=True, del integrator, simulation if not dry: # pragma: no-cover - return { + output = { 'system_pdb': shared_basepath / output_settings.preminimized_structure, 'minimized_pdb': shared_basepath / output_settings.minimized_structure, - 'nvt_equil_pdb': shared_basepath / output_settings.equil_NVT_structure, - 'npt_equil_pdb': shared_basepath / output_settings.equil_NPT_structure, 'nc': shared_basepath / output_settings.production_trajectory_filename, 'last_checkpoint': shared_basepath / output_settings.checkpoint_storage_filename, } + if output_settings.equil_nvt_structure: + output['nvt_equil_pdb'] = shared_basepath / output_settings.equil_nvt_structure + if output_settings.equil_npt_structure: + output['npt_equil_pdb'] = shared_basepath / output_settings.equil_npt_structure + + return output else: return {'debug': {'system': stateA_system}} diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index b07d57d6b..bbdb35d5f 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -188,9 +188,10 @@ class Config: """ constraint_tolerance = 1e-06 """Tolerance for the constraint solver. Default 1e-6.""" - barostat_frequency = 25 * unit.timestep # todo: IntQuantity + barostat_frequency: FloatQuantity['timestep'] = 25 * unit.timestep # todo: IntQuantity """ Frequency at which volume scaling changes should be attempted. 
+ Note: The barostat frequency is ignored for gas-phase simulations. Default 25 * unit.timestep. """ remove_com: bool = False @@ -449,11 +450,12 @@ class MDSimulationSettings(SimulationSettings): class Config: arbitrary_types_allowed = True - equilibration_length_nvt: unit.Quantity + equilibration_length_nvt: Optional[FloatQuantity['nanosecond']] """ Length of the equilibration phase in the NVT ensemble in units of time. The total number of steps from this equilibration length (i.e. ``equilibration_length_nvt`` / :class:`IntegratorSettings.timestep`). + If None, no NVT equilibration will be performed. """ @@ -475,12 +477,12 @@ class Config: minimized_structure = 'minimized.pdb' """Path to the pdb file of the system after minimization. Only the specified atom subset is saved. Default 'minimized.pdb'.""" - equil_NVT_structure = 'equil_NVT.pdb' + equil_nvt_structure: Optional[str] = 'equil_nvt.pdb' """Path to the pdb file of the system after NVT equilibration. - Only the specified atom subset is saved. Default 'equil_NVT.pdb'.""" - equil_NPT_structure = 'equil_NPT.pdb' + Only the specified atom subset is saved. Default 'equil_nvt.pdb'.""" + equil_npt_structure: Optional[str] = 'equil_npt.pdb' """Path to the pdb file of the system after NPT equilibration. - Only the specified atom subset is saved. Default 'equil_NPT.pdb'.""" + Only the specified atom subset is saved. Default 'equil_npt.pdb'.""" log_output = 'simulation.log' """ Filename for writing the log of the MD simulation, including timesteps, diff --git a/openfe/protocols/openmm_utils/settings_validation.py b/openfe/protocols/openmm_utils/settings_validation.py index 0526f95f2..5df183934 100644 --- a/openfe/protocols/openmm_utils/settings_validation.py +++ b/openfe/protocols/openmm_utils/settings_validation.py @@ -160,11 +160,15 @@ def convert_checkpoint_interval_to_iterations( iterations : int The number of iterations per checkpoint. 
""" - return divmod_time_and_check( - checkpoint_interval, time_per_iteration, - "amount of time per checkpoint", - "amount of time per state MCM move attempt" - ) + iterations, rem = divmod_time(checkpoint_interval, time_per_iteration) + + if rem: + errmsg = (f"The amount of time per checkpoint {checkpoint_interval} " + "does not evenly divide by the amount of time per " + f"state MCMC move attempt {time_per_iteration}") + raise ValueError(errmsg) + + return iterations def convert_steps_per_iteration( diff --git a/openfe/tests/protocols/test_openmm_afe_slow.py b/openfe/tests/protocols/test_openmm_afe_slow.py index 020be4efa..4c19394aa 100644 --- a/openfe/tests/protocols/test_openmm_afe_slow.py +++ b/openfe/tests/protocols/test_openmm_afe_slow.py @@ -47,8 +47,13 @@ def test_openmm_run_engine(platform, s = openmm_afe.AbsoluteSolvationProtocol.default_settings() s.protocol_repeats = 1 s.solvent_output_settings.output_indices = "resname UNK" + s.vacuum_equil_simulation_settings.equilibration_length = 0.1 * unit.picosecond + s.vacuum_equil_simulation_settings.production_length = 0.1 * unit.picosecond s.vacuum_simulation_settings.equilibration_length = 0.1 * unit.picosecond s.vacuum_simulation_settings.production_length = 0.1 * unit.picosecond + s.solvent_equil_simulation_settings.equilibration_length_nvt = 0.1 * unit.picosecond + s.solvent_equil_simulation_settings.equilibration_length = 0.1 * unit.picosecond + s.solvent_equil_simulation_settings.production_length = 0.1 * unit.picosecond s.solvent_simulation_settings.equilibration_length = 0.1 * unit.picosecond s.solvent_simulation_settings.production_length = 0.1 * unit.picosecond s.vacuum_engine_settings.compute_platform = platform diff --git a/openfe/tests/protocols/test_solvation_afe_tokenization.py b/openfe/tests/protocols/test_solvation_afe_tokenization.py index 436930e10..94dd26c45 100644 --- a/openfe/tests/protocols/test_solvation_afe_tokenization.py +++ 
b/openfe/tests/protocols/test_solvation_afe_tokenization.py @@ -49,7 +49,7 @@ def protocol_result(afe_solv_transformation_json): class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin): cls = openmm_afe.AbsoluteSolvationProtocol - key = "AbsoluteSolvationProtocol-04f686419a0bf9568c0475b6317278cd" + key = "AbsoluteSolvationProtocol-38eae61f1138e3b44d15be9d03e0d57e" repr = f"<{key}>" @pytest.fixture()