diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 633ec884a..6c3a47ac7 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -31,6 +31,7 @@ from openmmtools import multistate from openmmtools.states import (SamplerState, ThermodynamicState, + GlobalParameterState, create_thermodynamic_state_protocol,) from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory, AlchemicalState,) @@ -469,45 +470,70 @@ def _get_modeller( def _get_omm_objects( self, - system_modeller: app.Modeller, - system_generator: SystemGenerator, - smc_components: list[OFFMolecule], - ) -> tuple[app.Topology, openmm.unit.Quantity, openmm.System]: + settings: dict[str, SettingsBaseModel], + protein_component: Optional[ProteinComponent], + solvent_component: Optional[SolventComponent], + smc_components: dict[SmallMoleculeComponent, OFFMolecule], + ) -> tuple[ + app.Topology, + openmm.System, + openmm.unit.Quantity, + dict[str, npt.NDArray], + ]: """ Get the OpenMM Topology, Positions and System of the parameterised system. Parameters ---------- - system_modeller : app.Modeller - OpenMM Modeller object representing the system to be - parametrized. - system_generator : SystemGenerator - SystemGenerator object to create a System with. - smc_components : list[openff.toolkit.Molecule] - A list of openff Molecules to add to the system. + settings : dict[str, SettingsBaseModel] + Protocol settings + protein_component : Optional[ProteinComponent] + Protein component for the system. + solvent_component : Optional[SolventComponent] + Solvent component for the system. + smc_components : dict[str, OFFMolecule] + SmallMoleculeComponents defining ligands to be added to the system Returns ------- topology : app.Topology - Topology object describing the parameterized system + OpenMM Topology object describing the parameterized system. system : openmm.System - An OpenMM System of the alchemical system. 
- positionns : openmm.unit.Quantity + An non-alchemical OpenMM System of the simulated system. + positions : openmm.unit.Quantity Positions of the system. + comp_resids : dict[str, npt.NDArray] + A dictionary of residues for each component in the System. """ - topology = system_modeller.getTopology() + if self.verbose: + self.logger.info("Parameterizing system") + + system_generator = self._get_system_generator( + settings, solvent_component + ) + + modeller, comp_resids = self._get_modeller( + protein_component, + solvent_component, + smc_components, + system_generator, + settings['charge_settings'], + settings['solvation_settings'] + ) + + topology = modeller.getTopology() # roundtrip positions to remove vec3 issues - positions = to_openmm(from_openmm(system_modeller.getPositions())) + positions = to_openmm(from_openmm(modeller.getPositions())) # Block out oechem backend to avoid any issues with # smiles roundtripping between rdkit and oechem with without_oechem_backend(): system = system_generator.create_system( - system_modeller.topology, + modeller.topology, molecules=smc_components, ) - return topology, system, positions + return topology, system, positions, comp_resids def _get_lambda_schedule( self, settings: dict[str, SettingsBaseModel] @@ -533,13 +559,16 @@ def _get_lambda_schedule( lambda_elec = settings['lambda_settings'].lambda_elec lambda_vdw = settings['lambda_settings'].lambda_vdw + lambda_rest = settings['lambda_settings'].lambda_restraints # Reverse lambda schedule since in AbsoluteAlchemicalFactory 1 # means fully interacting, not stateB - lambda_elec = [1-x for x in lambda_elec] - lambda_vdw = [1-x for x in lambda_vdw] - lambdas['lambda_electrostatics'] = lambda_elec - lambdas['lambda_sterics'] = lambda_vdw + for name, schedule in [ + ('lambda_electrostatics', lambda_elec), + ('lambda_sterics', lambda_vdw), + ('lambda_restraints', lambda_rest), + ]: + lambdas[name] = [1-x for x in schedule] return lambdas @@ -547,7 +576,7 @@ def 
_add_restraints(self, system, topology, settings): """ Placeholder method to add restraints if necessary """ - return + return None, None, system def _get_alchemical_system( self, @@ -607,6 +636,7 @@ def _get_states( settings: dict[str, SettingsBaseModel], lambdas: dict[str, npt.NDArray], solvent_comp: Optional[SolventComponent], + restraint_state: Optional[GlobalParameterState], ) -> tuple[list[SamplerState], list[ThermodynamicState]]: """ Get a list of sampler and thermodynmic states from an @@ -624,6 +654,8 @@ def _get_states( A dictionary of lambda scales. solvent_comp : Optional[SolventComponent] The solvent component of the system, if there is one. + restraint_state : Optional[GlobalParameterState] + The restraint parameter control state, if there is one. Returns ------- @@ -641,9 +673,14 @@ def _get_states( if solvent_comp is not None: constants['pressure'] = ensure_quantity(pressure, 'openmm') + if restraint_state is not None: + composable_states = [alchemical_state, restraint_state] + else: + composable_states = [alchemical_state,] + cmp_states = create_thermodynamic_state_protocol( alchemical_system, protocol=lambdas, - constants=constants, composable_states=[alchemical_state], + constants=constants, composable_states=composable_states, ) sampler_state = SamplerState(positions=positions) @@ -873,6 +910,7 @@ def _run_simulation( sampler: multistate.MultiStateSampler, reporter: multistate.MultiStateReporter, settings: dict[str, SettingsBaseModel], + standard_state_corr: Optional[unit.Quantity] dry: bool ): """ @@ -886,6 +924,8 @@ def _run_simulation( The reporter associated with the sampler. settings : dict[str, SettingsBaseModel] The dictionary of settings for the protocol. + standard_state_corr : Optional[unit.Quantity] + The standard state correction, if available. 
dry : bool Whether or not to dry run the simulation @@ -944,7 +984,12 @@ def _run_simulation( analyzer.plot(filepath=self.shared_basepath, filename_prefix="") analyzer.close() - return analyzer.unit_results_dict + return_dict = analyzer.unit_results_dict + + if standard_state_corr is not None: + return_dict['standard_state_correction'] = standard_state_corr + + return return_dict else: # close reporter when you're done, prevent file handle clashes @@ -991,44 +1036,40 @@ def run(self, dry=False, verbose=True, # 2. Get settings settings = self._handle_settings() - # 3. Get system generator - system_generator = self._get_system_generator(settings, solv_comp) - - # 4. Get modeller - system_modeller, comp_resids = self._get_modeller( - prot_comp, solv_comp, smc_comps, system_generator, - settings['charge_settings'], - settings['solvation_settings'], + # 3. Get OpenMM topology, positions, and system + omm_topology, omm_system, position, comp_resids = self._get_omm_objects( + settings, prot_comps, solv_comps, smc_comps, ) - # 5. Get OpenMM topology, positions and system - omm_topology, omm_system, positions = self._get_omm_objects( - system_modeller, system_generator, list(smc_comps.values()) - ) - - # 6. Pre-equilbrate System (Test + Avoid NaNs + get stable system) + # 4. Pre-equilbrate System (Test + Avoid NaNs + get stable system) positions = self._pre_equilibrate( omm_system, omm_topology, positions, settings, dry ) - # 7. Get lambdas + # 5. Get lambdas lambdas = self._get_lambda_schedule(settings) - # 8. Add restraints - self._add_restraints(omm_system, omm_topology, settings) + # 6. Add restraints + restraint_parameter_state, standard_state_corr, omm_system = self._add_restraints( + omm_system, omm_topology, settings + ) - # 9. Get alchemical system + # 7. Get alchemical system alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( omm_topology, omm_system, comp_resids, alchem_comps ) - # 10. Get compound and sampler states + # 7. 
Get compound and sampler states sampler_states, cmp_states = self._get_states( - alchem_system, positions, settings, - lambdas, solv_comp + alchem_system, + positions, + settings, + lambdas, + solv_comp, + restraint_parameter_state, ) - # 11. Create the multistate reporter & create PDB + # 9. Create the multistate reporter & create PDB reporter = self._get_reporter( omm_topology, positions, settings['simulation_settings'], @@ -1037,19 +1078,19 @@ def run(self, dry=False, verbose=True, # Wrap in try/finally to avoid memory leak issues try: - # 12. Get context caches + # 10. Get context caches energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( settings['forcefield_settings'], settings['engine_settings'] ) - # 13. Get integrator + # 11. Get integrator integrator = self._get_integrator( settings['integrator_settings'], settings['simulation_settings'], ) - # 14. Get sampler + # 12. Get sampler sampler = self._get_sampler( integrator, reporter, settings['simulation_settings'], settings['thermo_settings'], @@ -1057,9 +1098,13 @@ def run(self, dry=False, verbose=True, energy_ctx_cache, sampler_ctx_cache ) - # 15. Run simulation + # 13. Run simulation unit_result_dict = self._run_simulation( - sampler, reporter, settings, dry + sampler, + reporter, + settings, + standard_state_corr, + dry ) finally: diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py index 1e45e007f..860a67a61 100644 --- a/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -30,6 +30,10 @@ MultiStateOutputSettings, MDSimulationSettings, MDOutputSettings, + BaseRestraintSettings, + HarmonicRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, ) import numpy as np @@ -217,3 +221,106 @@ def must_be_positive(cls, v): including the partial charge assignment method, and the number of conformers used to generate the partial charges. 
""" + + +class AbsoluteBindingSettings(SettingsBaseModel): + """ + Configuration object for ``AbsoluteBindingPProtocol`` + + See Also + -------- + openfe.protocols.openmm_afe.AbsoluteBindingProtocol + """ + protocol_repeats: int + """ + The number of completely independent repeats of the entire sampling + process. The mean of the repeats defines the final estimate of FE + difference, while the variance between repeats is used as the uncertainty. + """ + + @validator('protocol_repeats') + def must_be_positive(cls, v): + if v <= 0: + errmsg = f"protocol_repeats must be a positive value, got {v}." + raise ValueError(errmsg) + return v + + forcefield_settings: OpenMMSystemGeneratorFFSettings + """Parameters to set up the force field with OpenMM Force Fields""" + thermo_settings: ThermoSettings + """Settings for thermodynamic parameters""" + + solvation_settings: OpenMMSolvationSettings + """Settings for solvating the system.""" + + # Alchemical settings + alchemical_settings: AlchemicalSettings + """ + Alchemical protocol settings. + """ + lambda_settings: LambdaSettings + """ + Settings for controlling the lambda schedule for the different components + (vdw, elec, restraints). + """ + + # MD Engine things + engine_settings: OpenMMEngineSettings + """ + Settings specific to the OpenMM engine, such as the compute platform. + """ + + # Sampling State defining things + integrator_settings: IntegratorSettings + """ + Settings for controlling the integrator, such as the timestep and + barostat settings. + """ + + # Simulation run settings + complex_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical complex simulation control settings. + """ + complex_simulation_settings: MultiStateSimulationSettings + """ + Simulation control settings, including simulation lengths + for the complex transformation. + """ + solvent_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical solvent simulation control settings. 
+ """ + solvent_simulation_settings: MultiStateSimulationSettings + """ + Simulation control settings, including simulation lengths + for the solvent transformation. + """ + complex_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the complex non-alchemical equilibration. + """ + complex_output_settings: MultiStateOutputSettings + """ + Simulation output settings for the complex transformation. + """ + solvent_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the solvent non-alchemical equilibration. + """ + solvent_output_settings: MultiStateOutputSettings + """ + Simulation output settings for the solvent transformation. + """ + partial_charge_settings: OpenFFPartialChargeSettings + """ + Settings for controlling how to assign partial charges, + including the partial charge assignment method, and the + number of conformers used to generate the partial charges. + """ + restraint_settings: BaseRestraintSettings + """ + Settings controlling how restraints are added to the system in the + complex simulation. + """ + diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py new file mode 100644 index 000000000..3c1e63acb --- /dev/null +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -0,0 +1,968 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Binding AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_binding_afe_method` +=============================================================================================================== + +This module implements the necessary methodology tooling to calculate an +absolute binding free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Disappearing molecules are only allowed in state A. Support for + appearing molecules will be added in due course. +* Only small molecules are allowed to act as alchemical molecules. + Alchemically changing protein or solvent components would induce + perturbations which are too large to be handled by this Protocol. + + +Acknowledgements +---------------- +* Originally based on hydration.py in + `espaloma_charge `_ + +""" +from __future__ import annotations + +import pathlib +import logging +import warnings +from collections import defaultdict +import gufe +from gufe.components import Component +import itertools +import numpy as np +import numpy.typing as npt +from openff.units import unit +from openmmtools import multistate +from openmmtools.state import ThermodynamicState, GlobalParameterState +from typing import Optional, Union +from typing import Any, Iterable +import uuid + +from gufe import ( + settings, + ChemicalSystem, SmallMoleculeComponent, + ProteinComponent, SolventComponent +) +from openfe.protocols.openmm_afe.equil_afe_settings import ( + AbsoluteSolvationSettings, + OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, + MDSimulationSettings, MDOutputSettings, + MultiStateSimulationSettings, OpenMMEngineSettings, + IntegratorSettings, MultiStateOutputSettings, + OpenFFPartialChargeSettings, + SettingsBaseModel, + 
HarmonicRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, +) +from ..openmm_utils import system_validation, settings_validation +from .base import BaseAbsoluteUnit +from openfe.utils import log_system_probe +from openfe.due import due, Doi + + +due.cite(Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + +due.cite(Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + +due.cite(Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + + +logger = logging.getLogger(__name__) + + +class AbsoluteBindingProtocolResult(gufe.ProtocolResult): + """Dict-like container for the output of an AbsoluteBindingProtocol + """ + def __init__(self, **data): + super().__init__(**data) + # TODO: Detect when we have extensions and stitch these together? + if any(len(pur_list) > 2 for pur_list + in itertools.chain(self.data['solvent'].values(), self.data['vacuum'].values())): + raise NotImplementedError("Can't stitch together results yet") + + def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[unit.Quantity, unit.Quantity]]] + A dictionary, keyed `solvent`, `complex`, and 'standard_state' + representing each portion of the thermodynamic cycle, + with lists of tuples containing the individual free energy + estimates and, for 'solvent' and 'complex', the associated MBAR + uncertainties for each repeat of that simulation type. + + TODO + ---- + * Work out properly what to do with the standard state correction. 
+ """ + complex_dGs = [] + correction_dGs = [] + solv_dGs = [] + + for pus in self.data['complex'].values(): + complex_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + correction_dGs.append(( + pus[0].outputs['standard_state_correction'] + )) + + for pus in self.data['solvent'].values(): + solv_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + + return {'solvent': solv_dGs, 'complex': complex_dGs, 'standard_state': correction_dGs} + + def get_estimate(self): + """Get the binding free energy estimate for this calculation. + + Returns + ------- + dG : unit.Quantity + The binding free energy. This is a Quantity defined with units. + """ + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_dG = _get_average(individual_estimates['complex']) + solv_dG = _get_average(individual_estimates['solvent']) + standard_state_dG = _get_average( + individual_estimates['standard_state'] + ) + + return - complex_dG + solv_dG + standard_state_dG + + def get_uncertainty(self): + """Get the binding free energy error for this calculation. + + Returns + ------- + err : unit.Quantity + The standard deviation between estimates of the binding free + energy. This is a Quantity defined with units. 
+ """ + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_err = _get_stdev(individual_estimates['complex']) + solv_err = _get_stdev(individual_estimates['solvent']) + standard_state_err = _get_stdev(individual_estimates['standard_state']) + + # return the combined error + return np.sqrt(complex_err**2 + solv_err**2 + standard_state_err**2) + + def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]]: + """ + Get the reverse and forward analysis of the free energies. + + Returns + ------- + forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]] + A dictionary, keyed `solvent` and `complex` for each leg of the + thermodynamic cycle which each contain a list of dictionaries + containing the forward and reverse analysis of each repeat + of that simulation type. + + The forward and reverse analysis dictionaries contain: + - `fractions`: npt.NDArray + The fractions of data used for the estimates + - `forward_DGs`, `reverse_DGs`: unit.Quantity + The forward and reverse estimates for each fraction of data + - `forward_dDGs`, `reverse_dDGs`: unit.Quantity + The forward and reverse estimate uncertainty for each + fraction of data. + + If one of the cycle leg list entries is ``None``, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + Raises + ------ + UserWarning + * If any of the forward and reverse dictionaries are ``None`` in a + given thermodynamic cycle leg. 
+ """ + + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]] = {} + + for key in ['solvent', 'complex']: + forward_reverse[key] = [ + pus[0].outputs['forward_and_reverse_energies'] + for pus in self.data[key].values() + ] + + if None in forward_reverse[key]: + wmsg = ( + "One or more ``None`` entries were found in the forward " + f"and reverse dictionaries of the repeats of the {key} " + "calculations. This is likely caused by an MBAR convergence " + "failure caused by too few independent samples when " + "calculating the free energies of the 10% timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get a the MBAR overlap estimates for all legs of the simulation. + + Returns + ------- + overlap_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `complex` for each + leg of the thermodynamic cycle, which each containing a + list of dictionaries with the MBAR overlap estimates of + each repeat of that simulation type. + + The underlying MBAR dictionaries contain the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + + for key in ['solvent', 'complex']: + overlap_stats[key] = [ + pus[0].outputs['unit_mbar_overlap'] + for pus in self.data[key].values() + ] + + return overlap_stats + + def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get the replica exchange transition statistics for all + legs of the simulation. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. 
+ + Returns + ------- + repex_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `complex` for each + leg of the thermodynamic cycle, which each containing + a list of dictionaries containing the replica transition + statistics for each repeat of that simulation type. + + The replica transition statistics dictionaries contain the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + try: + for key in ['solvent', 'complex']: + repex_stats[key] = [ + pus[0].outputs['replica_exchange_statistics'] + for pus in self.data[key].values() + ] + except KeyError: + errmsg = ("Replica exchange statistics were not found, " + "did you run a repex calculation?") + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> dict[str, list[npt.NDArray]]: + """ + Get the timeseries of replica states for all simulation legs. + + Returns + ------- + replica_states : dict[str, list[npt.NDArray]] + Dictionary keyed `solvent` and `complex` for each leg of + the thermodynamic cycle, with lists of replica states + timeseries for each repeat of that simulation type. 
+ """ + replica_states: dict[str, list[npt.NDArray]] = { + 'solvent': [], 'complex': [] + } + + def is_file(filename: str): + p = pathlib.Path(filename) + + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + + return p + + def get_replica_state(nc, chk): + nc = is_file(nc) + dir_path = nc.parents[0] + chk = is_file(dir_path / chk).name + + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode='r' + ) + + retval = np.asarray(reporter.read_replica_thermodynamic_states()) + reporter.close() + + return retval + + for key in ['solvent', 'complex']: + for pus in self.data[key].values(): + states = get_replica_state( + pus[0].outputs['nc'], + pus[0].outputs['last_checkpoint'], + ) + replica_states[key].append(states) + + return replica_states + + def equilibration_iterations(self) -> dict[str, list[float]]: + """ + Get the number of equilibration iterations for each simulation. + + Returns + ------- + equilibration_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `complex` for each leg + of the thermodynamic cycle, with lists containing the + number of equilibration iterations for each repeat + of that simulation type. + """ + equilibration_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'complex']: + equilibration_lengths[key] = [ + pus[0].outputs['equilibration_iterations'] + for pus in self.data[key].values() + ] + + return equilibration_lengths + + def production_iterations(self) -> dict[str, list[float]]: + """ + Get the number of production iterations for each simulation. + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `complex` for each leg of the + thermodynamic cycle, with lists with the number + of production iterations for each repeat of that simulation + type. 
+ """ + production_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'complex']: + production_lengths[key] = [ + pus[0].outputs['production_iterations'] + for pus in self.data[key].values() + ] + + return production_lengths + + +class AbsoluteBindingProtocol(gufe.Protocol): + """ + Absolute binding free energy calculations using OpenMM and OpenMMTools. + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSettings` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingProtocolResult` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSolventUnit` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingComplexUnit` + """ + result_cls = AbsoluteBindingProtocolResult + _settings: AbsoluteBindingSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. 
+ + Returns + ------- + Settings + a set of default settings + """ + return AbsoluteBindingSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * unit.kelvin, + pressure=1 * unit.bar, + ), + alchemical_settings=AlchemicalSettings(), + solvent_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, + 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], + lambda_restraints=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ), + complex_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, + 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, + 0.2, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], + lambda_restraints=[ + 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0], + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + restraint_settings=BoreschRestraintSettings(), + solvent_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.1 * unit.nanosecond, + equilibration_length=0.2 * unit.nanosecond, + production_length=0.5 * unit.nanosecond, + ), + solvent_equil_output_settings=MDOutputSettings( + equil_nvt_structure='equil_nvt_structure.pdb', + equil_npt_structure='equil_npt_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), + solvent_simulation_settings=MultiStateSimulationSettings( + n_replicas=14, + 
equilibration_length=1.0 * unit.nanosecond, + production_length=10.0 * unit.nanosecond, + ), + solvent_output_settings=MultiStateOutputSettings( + output_filename='solvent.nc', + checkpoint_storage_filename='solvent_checkpoint.nc', + ), + complex_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.25 * unit.nanosecond, + equilibration_length=0.5 * unit.nanosecond, + production_length=5.0 * unit.nanosecond, + ), + complex_equil_output_settings=MDOutputSettings( + equil_nvt_structure='equil_nvt_structure.pdb', + equil_npt_structure='equil_npt_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), + complex_simulation_settings=MultiStateSimulationSettings( + n_replicas=28, + equilibration_length=1 * unit.nanosecond, + production_length=10.0 * unit.nanosecond, + ), + complex_output_settings=MultiStateOutputSettings( + output_filename='complex.nc', + checkpoint_storage_filename='complex_checkpoint.nc' + ), + ) + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, stateB: ChemicalSystem, + ) -> None: + """ + A binding transformation is defined (in terms of gufe components) + as starting from one or more ligands with one protein and solvent, + that then ends up in a state with one less ligand. 
+ + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A + stateB : ChemicalSystem + The chemical system of end state B + + Raises + ------ + ValueError + If stateA does not contain a ProteinComponent + If stateA does not contain a SolventComponent + If stateA has more than one unique Component + If the stateA unique Component is not a SmallMoleculeComponent + If stateB contains any unique Components + """ + if not any( + isinstance(comp, ProteinComponent) for comp in stateA.values() + ): + errmsg = "No ProteinComponent found" + raise ValueError(errmsg) + + if not any( + isinstance(comp, SolventComponent) for comp in stateA.values() + ): + errmsg = "No SolventComponent found" + raise ValueError(errmsg) + + # Needs gufe 1.3 + diff = stateA.component_diff(stateB) + if len(diff[0]) > 1: + errmsg = ("More than unique components found in stateA, " + "only one alchemical species is supported") + raise ValueError(errmsg) + + if not isinstance(diff[0][0], SmallMoleculeComponent): + errmsg = ("Only dissapearing smalll molecule components " + "are supported by this protocol. " + f"Found a {type(diff[0][0])}") + raise ValueError(errmsg) + + if len(diff[1]) > 0: + errmsg = ("Unique components are found in stateB, " + "this should not happen") + raise ValueError(errmsg) + + @staticmethod + def _validate_lambda_schedule( + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> None: + """ + Checks that the lambda schedule is set up correctly. + + Parameters + ---------- + lambda_settings : LambdaSettings + the lambda schedule Settings + simulation_settings : MultiStateSimulationSettings + the settings for either the complex or solvent phase + + Raises + ------ + ValueError + If the number of lambda windows differs for electrostatics, sterics, + and restraints. + If the number of replicas does not match the number of lambda windows. + If there are states with naked charges. 
+ """ + + lambda_elec = lambda_settings.lambda_elec + lambda_vdw = lambda_settings.lambda_vdw + lambda_restraints = lambda_settings.lambda_restraints + n_replicas = simulation_settings.n_replicas + + # Ensure that all lambda components have equal amount of windows + lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] + it = iter(lambda_components) + the_len = len(next(it)) + if not all(len(l) == the_len for l in it): + errmsg = ( + "Components elec, vdw, and restraints must have equal amount" + f" of lambda windows. Got {len(lambda_elec)} elec lambda" + f" windows, {len(lambda_vdw)} vdw lambda windows, and" + f" {len(lambda_restraints)} restraints lambda windows.") + raise ValueError(errmsg) + + # Ensure that number of overall lambda windows matches number of lambda + # windows for individual components + if n_replicas != len(lambda_vdw): + errmsg = (f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(lambda_vdw)}") + raise ValueError(errmsg) + + # Check if there are no lambda windows with naked charges + for inx, lam in enumerate(lambda_elec): + if lam < 1 and lambda_vdw[inx] == 1: + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: lambda {inx}: " + f"elec {lam} vdW {lambda_vdw[inx]}") + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # TODO: extensions + if extends: # pragma: no-cover + raise NotImplementedError("Can't extend simulations yet") + + # Validate components and get alchemical components + self._validate_endstates(stateA, stateB) + alchem_comps = system_validation.get_alchemical_components( + stateA, stateB, + ) + + # Validate the lambda schedule +
self._validate_lambda_schedule(self.settings.solvent_lambda_settings, + self.settings.solvent_simulation_settings) + self._validate_lambda_schedule(self.settings.complex_lambda_settings, + self.settings.complex_simulation_settings) + + # Check nonbond & solvent compatibility + nonbonded_method = self.settings.forcefield_settings.nonbonded_method + # Use the more complete system validation solvent checks + system_validation.validate_solvent(stateA, nonbonded_method) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings( + self.settings.solvation_settings + ) + + # Get the name of the alchemical species + alchname = alchem_comps['stateA'][0].name + + # Create list units for complex and solvent transforms + + solvent_units = [ + AbsoluteBindingSolventUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, repeat_id=int(uuid.uuid4()), + name=(f"Absolute Binding, {alchname} solvent leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.protocol_repeats) + ] + + complex_units = [ + AbsoluteBindingComplexUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, repeat_id=int(uuid.uuid4()), + name=(f"Absolute Binding, {alchname} complex leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.protocol_repeats) + ] + + return solvent_units + complex_units + + def _gather( + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + ) -> dict[str, dict[str, Any]]: + # result units will have a repeat_id and generation + # first group according to repeat_id + unsorted_solvent_repeats = defaultdict(list) + unsorted_complex_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if not pu.ok(): + continue + if pu.outputs['simtype'] == 'solvent': + unsorted_solvent_repeats[pu.outputs['repeat_id']].append(pu) + else: + 
unsorted_complex_repeats[pu.outputs['repeat_id']].append(pu) + + repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { + 'solvent': {}, 'complex': {}, + } + for k, v in unsorted_solvent_repeats.items(): + repeats['solvent'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + + for k, v in unsorted_complex_repeats.items(): + repeats['complex'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + return repeats + + +class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the complex phase of an absolute binding free energy + """ + def _get_components(self): + """ + Get the relevant components for a complex transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a complex transformation. 
+ + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + * restraint_settings: BaseRestraintSettings + """ + prot_settings = self._inputs['protocol'].settings + + settings = {} + settings['forcefield_settings'] = prot_settings.forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['charge_settings'] = prot_settings.partial_charge_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['lambda_settings'] = prot_settings.complex_lambda_settings + settings['engine_settings'] = prot_settings.engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = prot_settings.complex_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.complex_equil_output_settings + settings['simulation_settings'] = prot_settings.complex_simulation_settings + settings['output_settings'] = prot_settings.complex_output_settings + settings['restraint_settings'] = prot_settings.restraint_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _add_restraints( + self, + system: openmm.System, + topology: openmm.app.Topology, + settings: dict[str, 
SettingsBaseModel] + ) -> tuple[GlobalParameterState, unit.Quantity, openmm.System]: + """ + Find and add restraints to the OpenMM System. + + Parameters + ---------- + system : openmm.System + The System to add the restraint to. + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + settings : dict[str, SettingsBaseModel] + A dictionary of settings that defines how to find and set + the restraint. + + Returns + ------- + restraint_parameter_state : RestraintParameterState + A RestraintParameterState object that defines the control + parameter for the restraint. + correction : unit.Quantity + The standard state correction for the restraint. + system : openmm.System + A copy of the System with the restraint added. + """ + from openfe.protocols.openmm_utils import ( + omm_restraints, geometry, search + ) + + if isinstance(settings['restraint_settings'], BoreschRestraintSettings): + geom = search.get_boresch_restraint( + topology, + self.shared_basepath / settings['equil_output_settings'].production_trajectory_filename + ) + + restraint = omm_restraints.BoreschRestraint( + settings['restraint_settings'], + geom, + controlling_parameter_name='lambda_restraints' + ) + else: + # TODO turn this into a direction for different restraint types supported? 
+ raise NotImplementedError() + + # We need a temporary thermodynamic state to add the restraint + # & get the correction + thermodynamic_state = ThermodynamicState( + system, + temperature=to_openmm(settings['thermo_settings'].temperature), + pressure=to_openmm(settings['thermo_settings'].pressure), + ) + + # Add the force to the thermodynamic state + restraint.add_force(thermodynamic_state) + # Get the standard state correction as a unit.Quantity + correction = restraint.get_standard_state_correction(thermodynamic_state) + + # Get the GlobalParameterState for the restraint + restraint_parameter_state = omm_restraints.RestraintParameterState( + lambda_restraints=1.0 + ) + return restraint_parameter_state, correction, thermodynamic_state.system + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'complex', + **outputs + } + + +class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the solvent phase of an absolute binding free energy + """ + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system.
+ """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a solvent transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs['protocol'].settings + + settings = {} + settings['forcefield_settings'] = prot_settings.forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['charge_settings'] = prot_settings.partial_charge_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['lambda_settings'] = prot_settings.solvent_lambda_settings + settings['engine_settings'] = prot_settings.engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = 
prot_settings.solvent_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.solvent_equil_output_settings + settings['simulation_settings'] = prot_settings.solvent_simulation_settings + settings['output_settings'] = prot_settings.solvent_output_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'solvent', + **outputs + } diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index b77df9dfb..12d16dd30 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -889,7 +889,7 @@ def _get_components(self): def _handle_settings(self) -> dict[str, SettingsBaseModel]: """ - Extract the relevant settings for a vacuum transformation. + Extract the relevant settings for a solvent transformation. Returns ------- diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index 63cb5789c..d1e03ffb1 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -660,3 +660,98 @@ class Config: Filename for writing the log of the MD simulation, including timesteps, energies, density, etc. """ + +class BaseRestraintSettings(SettingsBaseModel): + """ + Settings contolling how to add restraints to a system. 
+ """ + class Config: + arbitrary_types_allowed = True + + +class BaseDistanceRestraintSettings(BaseRestraintSettings): + """ + Base settings for a harmonic or flatbottom distance between two groups of + atoms. + """ + spring_constant: FloatQuantity['kilojoule_per_mole / nanometer**2'] + """ + The spring constant K between the two atom groups. + """ + atom_group1: Union[list[int], str] + """ + A definition for the first atom group to restrain. + Can either be a list of atom indices or an mdanalysis atom selection query. + """ + atom_group2: Union[list[int], str] + """ + A definition for the second atom group to restrain. + Can either be a list of atom indices or an mdanalysis atom selection query. + """ + +class HarmonicRestraintSettings(BaseDistanceRestraintSettings): + """ + Settings for a harmonic restraint between two groups of atoms. + """ + pass + + +class FlatBottomRestraintSettings(BaseDistanceRestraintSettings): + """ + Settings for a flat bottom restraint between two groups of atoms. + """ + well_radius: FloatQuantity['nanometer'] + """ + The well radius for the flat bottom restraint. + + TODO + ---- + * Implement an option to automatically pick the well radius. + """ + +class BoreschRestraintSettings(SettingsBaseModel): + """ + Settings for a Boresch-style restraint. + """ + host_atoms : Optional[list[int]] + """ + A list of 3 host atom indices. + + TODO: How do you relate this back to your input? + """ + guest_atoms : Optional[list[int]] + """ + A list of 3 guest atom indices. + + TODO: How do you relate this back to your input? + """ + K_r: FloatQuantity['kilocalorie_per_mole / nm ** 2'] = 2000.0 * unit.kilocalorie_per_mole / unit.nm **2 + """ + The spring constant for the distance restraint between + host_atom[2] and guest_atom[0].
+ """ + K_thetaA: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + angle(host_atoms[1], host_atoms[2], guest_atoms[2]) + """ + K_thetaB: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + angle(host_atoms[2], guest_atoms[0], guest_atoms[1]) + """ + K_phiA: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[0], host_atoms[1], host_atoms[2], guest_atoms[0]) + """ + K_phiB: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[1], host_atoms[2], guest_atoms[0], guest_atoms[1]) + """ + K_phiC: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[2], guest_atoms[0], guest_atoms[1], guest_aotms[2]) + """ diff --git a/openfe/protocols/restraint_utils/__init__.py b/openfe/protocols/restraint_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/restraint_utils/geometry/__init__.py b/openfe/protocols/restraint_utils/geometry/__init__.py new file mode 100644 index 000000000..1c1b4c56a --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseRestraintGeometry +from .harmonic import DistanceRestraintGeometry +from .flatbottom import FlatBottomDistanceGeometry +from .boresch import BoreschRestraintGeometry diff --git a/openfe/protocols/restraint_utils/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py new file mode 100644 index 000000000..0ca6ae200 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -0,0 +1,48 @@ +# This code is part of OpenFE and is licensed 
under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + """ + A base class for a restraint geometry. + """ + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v diff --git a/openfe/protocols/restraint_utils/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py new file mode 100644 index 000000000..6e740f48d --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -0,0 +1,940 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. 
+""" +import pathlib +from typing import Union, Optional, Iterable + +from rdkit import Chem + +import openmm +from openff.units import unit +from openff.models.types import FloatQuantity +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals +import numpy as np +import numpy.typing as npt +from scipy.stats import circmean + +from .base import HostGuestRestraintGeometry +from .utils import ( + _get_mda_coord_format, + _get_mda_topology_format, + get_aromatic_rings, + get_heavy_atom_idxs, + get_central_atom_idx, + is_collinear, + check_angular_variance, + check_dihedral_bounds, + check_angle_not_flat, + FindHostAtoms, + get_local_rmsf +) + + +class BoreschRestraintGeometry(HostGuestRestraintGeometry): + """ + A class that defines the restraint geometry for a Boresch restraint. + + The restraint is defined by the following: + + H2 G2 + - - + - - + H1 - - H0 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` and GX + the X index of ``guest_atoms``. + """ + r_aA0: FloatQuantity['nanometer'] + """ + The equilibrium distance between H0 and G0. + """ + theta_A0: FloatQuantity['radians'] + """ + The equilibrium angle value between H1, H0, and G0. + """ + theta_B0: FloatQuantity['radians'] + """ + The equilibrium angle value between H0, G0, and G1. + """ + phi_A0: FloatQuantity['radians'] + """ + The equilibrium dihedral value between H2, H1, H0, and G0. + """ + phi_B0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H1, H0, G0, and G1. + """ + phi_C0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H0, G0, G1, and G2. + """ + + def get_bond_distance( + self, + universe: mda.Universe, + ) -> unit.Quantity: + """ + Get the H0 - G0 distance. + + Parameters + ---------- + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + bond : unit.Quantity + The H0-G0 distance. 
+ """ + at1 = universe.atoms[self.host_atoms[0]] + at2 = universe.atoms[self.guest_atoms[0]] + bond = calc_bonds( + at1.position, + at2.position, + box=universe.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + def get_angles( + self, + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity]: + """ + Get the H1-H0-G0, and H0-G0-G1 angles. + + Parameters + ---------- + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + angleA : unit.Quantity + The H1-H0-G0 angle. + angleB : unit.Quantity + The H0-G0-G1 angle. + """ + at1 = universe.atoms[self.host_atoms[1]] + at2 = universe.atoms[self.host_atoms[0]] + at3 = universe.atoms[self.guest_atoms[0]] + at4 = universe.atoms[self.guest_atoms[1]] + + angleA = calc_angles( + at1.position, + at2.position, + at3.position, + box=universe.atoms.dimensions + ) + angleB = calc_angles( + at2.position, + at3.position, + at4.position, + box=universe.atoms.dimensions + ) + return angleA, angleB + + def get_dihedrals( + self, + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity, unit.Quantity]: + """ + Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. + + Parameters + ---------- + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + dihA : unit.Quantity + The H2-H1-H0-G0 angle. + dihB : unit.Quantity + The H1-H0-G0-G1 angle. + dihC : unit.Quantity + The H0-G0-G1-G2 angle. 
+ """ + at1 = universe.atoms[self.host_atoms[2]] + at2 = universe.atoms[self.host_atoms[1]] + at3 = universe.atoms[self.host_atoms[0]] + at4 = universe.atoms[self.guest_atoms[0]] + at5 = universe.atoms[self.guest_atoms[1]] + at6 = universe.atoms[self.guest_atoms[2]] + + dihA = calc_dihedrals( + at1.position, at2.position, at3.position, at4.position, + box=universe.dimensions + ) + dihB = calc_dihedrals( + at2.position, at3.position, at4.position, at5.position, + box=universe.dimensions + ) + dihC = calc_dihedrals( + at3.position, at4.position, at5.position, at6.position, + box=universe.dimensions + ) + return dihA, dihB, dihC + + +def _sort_by_distance_from_atom( + rdmol: Chem.Mol, target_idx: int, atom_idxs: Iterable[int] +) -> list[int]: + """ + Sort a list of RDMol atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the atom to measure from. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : Chem.Mol + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool( + rdmol: Chem.Mol, atom_idx: int, atom_pool: list[int] +) -> list[tuple[int, int, int]]: + """ + Get all bonded angles starting from ``atom_idx`` from a pool of atoms. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule + atom_idx : int + The index of the atom to search angles from. + atom_pool : list[int] + The list of indices to pick possible angle partners from. + + Returns + ------- + list[tuple[int, int, int]] + A list of tuples containing all the angles. 
+ """ + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def _get_atom_pool( + rdmol: Chem.Mol, + rmsf: npt.NDArray, + rmsf_cutoff: unit.Quantity +) -> Optional[set[int]]: + """ + Filter atoms based on rmsf & rings, defaulting to heavy atoms if + there are not enough. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule to search through + rmsf : npt.NDArray + A 1-D array of RMSF values for each atom. + + Returns + ------- + atom_pool : Optional[set[int]] + """ + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. 
+ atom_pool = set() + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + heavy_atoms = get_heavy_atom_idxs(rdmol) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + return None + + return atom_pool + + +def get_guest_atom_candidates( + topology: Union[str, pathlib.Path, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + rdmol: Chem.Mol, + guest_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.nanometer, +) -> list[tuple[int]]: + """ + Get a list of potential ligand atom choices for a Boresch restraint + being applied to a given small molecule. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + rdmol : Chem.Mol + An RDKit Molecule representing the small molecule ordered in + the same way as it is listed in the topology. + guest_idxs : list[int] + The ligand indices in the topology. + rmsf_cutoff : unit.Quantity + The RMSF filter cut-off. + + Returns + ------- + angle_list : list[tuple[int]] + A list of tuples for each valid G0, G1, G2 angle. If ``None``, no + angles could be found. + + Raises + ------ + ValueError + If no suitable ligand atoms could be found. + + TODO + ---- + Should the RDMol have a specific frame position? + """ + u = mda.Universe( + topology, + trajectory, + format=_get_mda_coord_format(trajectory), + topology_format=_get_mda_topology_format(topology), + ) + + ligand_ag = u.atoms[guest_idxs] + + # 0. Get the ligand RMSF + rmsf = get_local_rmsf(ligand_ag) + u.trajectory[-1] # forward to the last frame + + # 1. 
Get the pool of atoms to work with + atom_pool = _get_atom_pool(rdmol, rmsf, rmsf_cutoff) + + if atom_pool is None: + # We don't have enough atoms so we raise an error + errmsg = "No suitable ligand atoms were found for the restraint" + raise ValueError(errmsg) + + # 2. Get the central atom + center = get_central_atom_idx(rdmol) + + # 3. Sort the atom pool based on their distance from the center + sorted_atom_pool = _sort_by_distance_from_atom(rdmol, center, atom_pool) + + # 4. Get a list of probable angles + angles_list = [] + for atom in sorted_atom_pool: + angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) + for angle in angles: + # Check that the angle is at least not collinear + angle_ag = ligand_ag.atoms[list(angle)] + if not is_collinear(ligand_ag.positions, angle, u.dimensions): + angles_list.append( + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) + ) + + return angles_list + + +def get_host_atom_candidates( + topology: Union[str, pathlib.Path, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + host_idxs: list[int], + l1_idx: int, + host_selection: str, + dssp_filter: bool = False, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, + min_distance: unit.Quantity = 1 * unit.nanometer, + max_distance: unit.Quantity = 3 * unit.nanometer, +) -> npt.NDArray: + """ + Get a list of suitable host atoms. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + host_idxs : list[int] + A list of the host indices in the system topology. + l1_idx : int + The index of the proposed l1 binding atom. + host_selection : str + An MDAnalysis selection string to filter the host by. + dssp_filter : bool + Whether or not to apply a DSSP filter on the host selection. + rmsf_cutoff : unit.Quantity + The maximum RMSF value allowed for any candidate host atom.
+ min_distance : unit.Quantity + The minimum search distance around l1 for suitable candidate atoms. + max_distance : unit.Quantity + The maximum search distance around l1 for suitable candidate atoms. + + Return + ------ + NDArray + Array of host atom indexes + """ + u = mda.Universe( + topology, + trajectory, + format=_get_mda_coord_format(trajectory), + topology_format=_get_mda_topology_format(topology), + ) + + host_ag1 = u.atoms[host_idxs] + host_ag2 = host_ag1.select_atoms(host_selection) + + # 0. TODO: implement DSSP filter + # Should be able to just call MDA's DSSP method + # but will need to catch an exception + if dssp_filter: + raise NotImplementedError( + "DSSP filtering is not currently implemented" + ) + + # 1. Get the RMSF & filter + rmsf = get_local_rmsf(host_ag2) + protein_ag3 = host_ag2.atoms[rmsf < rmsf_cutoff] + + # 2. Search of atoms within the min/max cutoff + atom_finder = FindHostAtoms( + protein_ag3, u.atoms[l1_idx], min_distance, max_distance + ) + atom_finder.run() + return atom_finder.results.host_idxs + + +class EvaluateHostAtoms1(AnalysisBase): + """ + Class to evaluate the suitability of a set of host atoms + as H1 atoms (i.e. the second host atom). + + Parameters + ---------- + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. + angle_force_constant : unit.Quantity + The force constant for the angle. 
+ temperature : unit.Quantity + The system temperature in Kelvin + """ + + def __init__( + self, + reference, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + **kwargs, + ): + super().__init__(reference.universe.trajectory, **kwargs) + + if len(reference) != 3: + errmsg = "Incorrect number of reference atoms passed" + raise ValueError(errmsg) + + self.reference = reference + self.host_atom_pool = host_atom_pool + self.minimum_distance = minimum_distance.to("angstrom").m + self.angle_force_constant = angle_force_constant + self.temperature = temperature + + def _prepare(self): + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + angle = calc_angles( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions, + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + dimensions=self.reference.dimensions, + ) + self.results.distances[i][self._frame_index] = distance + self.results.angles[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in 
class EvaluateHostAtoms2(EvaluateHostAtoms1):
    """
    Class to evaluate the suitability of a set of host atoms
    as H2 atoms (i.e. the third host atom).

    Takes the same constructor arguments as ``EvaluateHostAtoms1``. For
    each atom in ``host_atom_pool`` the distances to the first and second
    ``reference`` atoms, and the dihedral formed with all three reference
    atoms, are collected per frame and checked in ``_conclude``.
    """

    def _prepare(self):
        # One row per candidate atom, one column per analysed frame.
        shape = (len(self.host_atom_pool), self.n_frames)
        self.results.distances1 = np.zeros(shape)
        self.results.distances2 = np.zeros(shape)
        self.results.dihedrals = np.zeros(shape)
        self.results.collinear = np.zeros(shape, dtype=bool)
        # Must start as all-False: ``np.empty`` would leave arbitrary
        # values for atoms that fail the checks in ``_conclude``.
        self.results.valid = np.zeros(len(self.host_atom_pool), dtype=bool)

    def _single_frame(self):
        ref0, ref1, ref2 = (atom.position for atom in self.reference.atoms)
        box = self.reference.dimensions
        frame = self._frame_index

        for i, at in enumerate(self.host_atom_pool):
            self.results.distances1[i][frame] = calc_bonds(
                at.position, ref0, box=box,
            )
            self.results.distances2[i][frame] = calc_bonds(
                at.position, ref1, box=box,
            )
            self.results.dihedrals[i][frame] = calc_dihedrals(
                at.position, ref0, ref1, ref2, box=box,
            )
            # ``is_collinear`` indexes into ``positions`` through ``atoms``;
            # the stacked array holds the candidate followed by the three
            # reference atoms, so all four rows are tested in order.
            self.results.collinear[i][frame] = is_collinear(
                positions=np.vstack((at.position, self.reference.positions)),
                atoms=[0, 1, 2, 3],
                dimensions=box,
            )

    def _conclude(self):
        for i in range(len(self.host_atom_pool)):
            distance1_bounds = bool(
                np.all(self.results.distances1[i] > self.minimum_distance)
            )
            distance2_bounds = bool(
                np.all(self.results.distances2[i] > self.minimum_distance)
            )

            mean_dihed = circmean(
                self.results.dihedrals[i], high=np.pi, low=-np.pi
            )
            # ``check_dihedral_bounds`` documents a Quantity input.
            dihed_bounds = check_dihedral_bounds(mean_dihed * unit.radians)
            dihed_variance = check_angular_variance(
                self.results.dihedrals[i] * unit.radians,
                upper_bound=np.pi * unit.radians,
                lower_bound=-np.pi * unit.radians,
                width=5.23 * unit.radians,
            )

            not_collinear = not all(self.results.collinear[i])

            self.results.valid[i] = bool(
                all(
                    [
                        distance1_bounds,
                        distance2_bounds,
                        dihed_bounds,
                        dihed_variance,
                        not_collinear,
                    ]
                )
            )
def _find_host_angle(
    g0g1g2_atoms,
    host_atom_pool,
    minimum_distance,
    angle_force_constant,
    temperature
):
    """
    Search ``host_atom_pool`` for three host atoms (H0, H1, H2) that can
    anchor a Boresch-like restraint to the guest atoms G0, G1, G2.

    Parameters
    ----------
    g0g1g2_atoms : MDAnalysis.AtomGroup
        The three guest atoms, in the order G0, G1, G2.
    host_atom_pool : MDAnalysis.AtomGroup
        The host atoms to choose H0, H1 and H2 from.
    minimum_distance : unit.Quantity
        Minimum distance passed on to the host atom evaluators.
    angle_force_constant : unit.Quantity
        Angle force constant passed on to the host atom evaluators.
    temperature : unit.Quantity
        Temperature passed on to the host atom evaluators.

    Returns
    -------
    Optional[list[int]]
        The universe indices of H0, H1 and H2 (in that order) for the first
        valid combination found, or ``None`` if there is none.
    """
    h0_eval = EvaluateHostAtoms1(
        g0g1g2_atoms,
        host_atom_pool,
        minimum_distance,
        angle_force_constant,
        temperature,
    )
    h0_eval.run()

    for i, valid_h0 in enumerate(h0_eval.results.valid):
        if not valid_h0:
            continue

        # The evaluator measures the candidate against reference[0], then
        # reference[1], ... so the new reference must be (H0, G0, G1).
        h0g0g1_atoms = host_atom_pool.atoms[i] + g0g1g2_atoms.atoms[:2]
        h1_eval = EvaluateHostAtoms1(
            h0g0g1_atoms,
            host_atom_pool,
            minimum_distance,
            angle_force_constant,
            temperature,
        )
        h1_eval.run()

        for j, valid_h1 in enumerate(h1_eval.results.valid):
            if not valid_h1:
                continue

            # Same reasoning as above: the reference becomes (H1, H0, G0).
            h1h0g0_atoms = host_atom_pool.atoms[j] + h0g0g1_atoms.atoms[:2]
            h2_eval = EvaluateHostAtoms2(
                h1h0g0_atoms,
                host_atom_pool,
                minimum_distance,
                angle_force_constant,
                temperature,
            )
            h2_eval.run()

            valid_h2 = np.asarray(h2_eval.results.valid, dtype=bool)
            if not valid_h2.any():
                continue

            # Sum of the trajectory-averaged distances to H1 and H0 for
            # every pool atom (element-wise, one value per atom).
            dsum_avgs = (
                np.asarray(h2_eval.results.distances1).mean(axis=1)
                + np.asarray(h2_eval.results.distances2).mean(axis=1)
            )
            # Only atoms that passed the H2 checks may be selected.
            dsum_avgs = np.where(valid_h2, dsum_avgs, np.inf)
            k = int(dsum_avgs.argmin())

            return [int(ix) for ix in host_atom_pool.atoms[[i, j, k]].ix]

    return None
def _get_restraint_distances(
    atomgroup: mda.AtomGroup
) -> tuple[unit.Quantity, ...]:
    """
    Get the bond, angle, and dihedral distances for an input atomgroup
    defining the six atoms for a Boresch-like restraint.

    The atoms must be in the order of H0, H1, H2, G0, G1, G2.

    Parameters
    ----------
    atomgroup : mda.AtomGroup
        An AtomGroup defining the restrained atoms in order.

    Returns
    -------
    bond : unit.Quantity
        The H0-G0 bond value.
    angle1 : unit.Quantity
        The H1-H0-G0 angle value.
    angle2 : unit.Quantity
        The H0-G0-G1 angle value.
    dihed1 : unit.Quantity
        The H2-H1-H0-G0 dihedral value.
    dihed2 : unit.Quantity
        The H1-H0-G0-G1 dihedral value.
    dihed3 : unit.Quantity
        The H0-G0-G1-G2 dihedral value.
    """
    coords = [atom.position for atom in atomgroup.atoms]
    box = atomgroup.dimensions

    # MDAnalysis reports lengths in angstrom; attach the unit so the bond
    # is a Quantity like the angles and dihedrals below.
    bond = calc_bonds(coords[0], coords[3], box=box) * unit.angstrom

    angles = [
        calc_angles(coords[a], coords[b], coords[c], box=box) * unit.radians
        for a, b, c in ([1, 0, 3], [0, 3, 4])
    ]

    dihedrals = [
        calc_dihedrals(
            coords[a], coords[b], coords[c], coords[d], box=box,
        ) * unit.radians
        for a, b, c, d in ([2, 1, 0, 3], [1, 0, 3, 4], [0, 3, 4, 5])
    ]

    return bond, angles[0], angles[1], dihedrals[0], dihedrals[1], dihedrals[2]
def _boresch_geometry_from_atoms(universe, host_atoms, guest_atoms):
    """Build a BoreschRestraintGeometry using final-frame equilibrium values."""
    # Plain Python lists so that ``+`` concatenates rather than adding
    # element-wise (``host_atoms`` may arrive as a NumPy array).
    host_atoms = [int(ix) for ix in host_atoms]
    guest_atoms = [int(ix) for ix in guest_atoms]

    # Set the equilibrium values as those of the final frame
    universe.trajectory[-1]
    atomgroup = universe.atoms[host_atoms + guest_atoms]
    bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances(atomgroup)

    return BoreschRestraintGeometry(
        host_atoms=host_atoms,
        guest_atoms=guest_atoms,
        r_aA0=bond,
        theta_A0=ang1,
        theta_B0=ang2,
        phi_A0=dih1,
        phi_B0=dih2,
        phi_C0=dih3,
    )


def find_boresch_restraint(
    topology: Union[str, pathlib.Path, openmm.app.Topology],
    trajectory: Union[str, pathlib.Path],
    guest_rdmol: Chem.Mol,
    guest_idxs: list[int],
    host_idxs: list[int],
    guest_restraint_atoms_idxs: Optional[list[int]] = None,
    host_restraint_atoms_idxs: Optional[list[int]] = None,
    host_selection: str = "all",
    dssp_filter: bool = False,
    rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer,
    host_min_distance: unit.Quantity = 1 * unit.nanometer,
    host_max_distance: unit.Quantity = 3 * unit.nanometer,
    angle_force_constant: unit.Quantity = (
        83.68 * unit.kilojoule_per_mole / unit.radians**2
    ),
    temperature: unit.Quantity = 298.15 * unit.kelvin,
) -> BoreschRestraintGeometry:
    """
    Find suitable Boresch-style restraints between a host and guest entity
    based on the approach of Baumann et al. [1] with some modifications.

    Parameters
    ----------
    topology : Union[str, pathlib.Path, openmm.app.Topology]
        A topology of the system.
    trajectory : Union[str, pathlib.Path]
        A path to a coordinate trajectory file.
    guest_rdmol : Chem.Mol
        An RDKit Mol for the guest molecule.
    guest_idxs : list[int]
        Indices in the topology for the guest molecule.
    host_idxs : list[int]
        Indices in the topology for the host molecule.
    guest_restraint_atoms_idxs : Optional[list[int]]
        User selected indices of the guest molecule itself (i.e. indexed
        starting at 0 for the guest molecule). This overrides the
        restraint search and a restraint using these indices will
        be returned. Must be defined alongside ``host_restraint_atoms_idxs``.
    host_restraint_atoms_idxs : Optional[list[int]]
        User selected indices of the host molecule itself (i.e. indexed
        starting at 0 for the host molecule). This overrides the
        restraint search and a restraint using these indices will
        be returned. Must be defined alongside ``guest_restraint_atoms_idxs``.
    host_selection : str
        An MDAnalysis selection string to sub-select the host atoms.
    dssp_filter : bool
        Whether or not to filter the host atoms by their secondary structure.
    rmsf_cutoff : unit.Quantity
        The cutoff value for atom root mean square fluctuation. Atoms with
        RMSF values above this cutoff will be disregarded.
        Must be in units compatible with nanometer.
    host_min_distance : unit.Quantity
        The minimum distance between any host atom and the guest G0 atom.
        Must be in units compatible with nanometer.
    host_max_distance : unit.Quantity
        The maximum distance between any host atom and the guest G0 atom.
        Must be in units compatible with nanometer.
    angle_force_constant : unit.Quantity
        The force constant for the G1-G0-H0 and G0-H0-H1 angles. Must be
        in units compatible with kilojoule / mole / radians ** 2.
    temperature : unit.Quantity
        The system temperature in units compatible with Kelvin.

    Returns
    -------
    BoreschRestraintGeometry
        An object defining the parameters of the Boresch-like restraint.

    Raises
    ------
    ValueError
        If only one of ``guest_restraint_atoms_idxs`` and
        ``host_restraint_atoms_idxs`` is defined, if no suitable guest
        atoms are found, or if no suitable host atoms are found.

    References
    ----------
    [1] Baumann, Hannah M., et al. "Broadening the scope of binding free energy
        calculations using a Separated Topologies approach." (2023).
    """
    user_guest = guest_restraint_atoms_idxs is not None
    user_host = host_restraint_atoms_idxs is not None

    if user_guest ^ user_host:
        # This is not an intended outcome, crash out here
        errmsg = (
            "both ``guest_restraints_atoms_idxs`` and "
            "``host_restraint_atoms_idxs`` "
            "must be set or both must be None. "
            f"Got {guest_restraint_atoms_idxs} and {host_restraint_atoms_idxs}"
        )
        raise ValueError(errmsg)

    u = mda.Universe(
        topology,
        trajectory,
        format=_get_mda_coord_format(trajectory),
        topology_format=_get_mda_topology_format(topology),
    )

    if user_guest and user_host:
        # In this case assume the picked atoms were intentional /
        # representative of the input and go with it.
        # The user indices are relative to each molecule, so they are
        # mapped to universe indices through the molecule's AtomGroup.
        guest_ag = u.atoms[guest_idxs]
        guest_angle = [
            at.ix for at in guest_ag.atoms[guest_restraint_atoms_idxs]
        ]
        host_ag = u.atoms[host_idxs]
        host_angle = [
            at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]
        ]
        return _boresch_geometry_from_atoms(u, host_angle, guest_angle)

    # 1. Fetch the guest angles
    guest_angles = get_guest_atom_candidates(
        topology=topology,
        trajectory=trajectory,
        rdmol=guest_rdmol,
        guest_idxs=guest_idxs,
        rmsf_cutoff=rmsf_cutoff,
    )

    if len(guest_angles) == 0:
        errmsg = "No suitable ligand atoms found for the restraint."
        raise ValueError(errmsg)

    # 2. We next fetch the host atom pool, searching around the first atom
    # of the first (preferred) guest angle.
    host_pool = get_host_atom_candidates(
        topology=topology,
        trajectory=trajectory,
        host_idxs=host_idxs,
        l1_idx=guest_angles[0][0],
        host_selection=host_selection,
        dssp_filter=dssp_filter,
        rmsf_cutoff=rmsf_cutoff,
        min_distance=host_min_distance,
        max_distance=host_max_distance,
    )

    # 3. We then loop through the guest angles to find suitable host atoms,
    # stopping at the first guest angle that yields a host angle.
    host_angle = None
    guest_angle = None
    for guest_angle in guest_angles:
        host_angle = _find_host_angle(
            g0g1g2_atoms=u.atoms[list(guest_angle)],
            host_atom_pool=u.atoms[host_pool],
            minimum_distance=0.5 * unit.nanometer,
            angle_force_constant=angle_force_constant,
            temperature=temperature,
        )
        if host_angle is not None:
            break

    if host_angle is None:
        errmsg = "No suitable host atoms could be found"
        raise ValueError(errmsg)

    return _boresch_geometry_from_atoms(u, host_angle, guest_angle)
class FlatBottomDistanceGeometry(DistanceRestraintGeometry):
    """
    Geometry definition for a flat bottom distance restraint acting
    between two groups of atoms.
    """
    # Radius of the restraint well, in units compatible with nanometer.
    well_radius: FloatQuantity["nanometer"]


class COMDistanceAnalysis(AnalysisBase):
    """
    Collect the centre of mass (COM) distance between two AtomGroups for
    every analysed frame.

    Parameters
    ----------
    group1 : MDAnalysis.AtomGroup
        Atoms defining the first centroid.
    group2 : MDAnalysis.AtomGroup
        Atoms defining the second centroid.

    Notes
    -----
    After ``run()``, ``results.distances`` holds one distance per frame.
    """
    _analysis_algorithm_is_parallelizable = False

    def __init__(self, group1, group2, **kwargs):
        trajectory = group1.universe.trajectory
        super().__init__(trajectory, **kwargs)

        self.ag1, self.ag2 = group1, group2

    def _prepare(self):
        # One slot per frame, filled in by ``_single_frame``.
        n_frames = self.n_frames
        self.results.distances = np.zeros(n_frames)

    def _single_frame(self):
        first_com = self.ag1.center_of_mass()
        second_com = self.ag2.center_of_mass()
        self.results.distances[self._frame_index] = calc_bonds(
            first_com,
            second_com,
            box=self.ag1.universe.dimensions,
        )

    def _conclude(self):
        # Nothing to post-process; the per-frame values are the result.
        return None
def get_flatbottom_distance_restraint(
    topology: Union[str, pathlib.Path, app.Topology],
    trajectory: Union[str, pathlib.Path],
    host_atoms: Optional[list[int]] = None,
    guest_atoms: Optional[list[int]] = None,
    host_selection: Optional[str] = None,
    guest_selection: Optional[str] = None,
    padding: unit.Quantity = 0.5 * unit.nanometer,
) -> FlatBottomDistanceGeometry:
    """
    Get a FlatBottomDistanceGeometry by analyzing the COM distance
    change between two sets of atoms.

    The ``well_radius`` is defined as the maximum COM distance plus
    ``padding``.

    Parameters
    ----------
    topology : Union[str, pathlib.Path, app.Topology]
        A topology defining the system.
    trajectory : Union[str, pathlib.Path]
        A coordinate trajectory for the system.
    host_atoms : Optional[list[int]]
        A list of host atoms indices. Either ``host_atoms`` or
        ``host_selection`` must be defined.
    guest_atoms : Optional[list[int]]
        A list of guest atoms indices. Either ``guest_atoms`` or
        ``guest_selection`` must be defined.
    host_selection : Optional[str]
        An MDAnalysis selection string to define the host atoms.
        Either ``host_atoms`` or ``host_selection`` must be defined.
    guest_selection : Optional[str]
        An MDAnalysis selection string to define the guest atoms.
        Either ``guest_atoms`` or ``guest_selection`` must be defined.
    padding : unit.Quantity
        A padding value to add to the ``well_radius`` definition.
        Must be in units compatible with nanometers.

    Returns
    -------
    FlatBottomDistanceGeometry
        An object defining a flat bottom restraint geometry.
    """
    u = mda.Universe(
        topology,
        trajectory,
        topology_format=_get_mda_topology_format(topology)
    )

    guest_ag = _get_mda_selection(u, guest_atoms, guest_selection)
    host_ag = _get_mda_selection(u, host_atoms, host_selection)

    # Resolve the indices from the AtomGroups: when a selection string was
    # used the ``*_atoms`` arguments are None and cannot be stored.
    guest_atoms = [a.ix for a in guest_ag]
    host_atoms = [a.ix for a in host_ag]

    com_dists = COMDistanceAnalysis(guest_ag, host_ag)
    com_dists.run()

    # MDAnalysis distances are in angstrom.
    well_radius = com_dists.results.distances.max() * unit.angstrom + padding
    return FlatBottomDistanceGeometry(
        guest_atoms=guest_atoms, host_atoms=host_atoms, well_radius=well_radius
    )
class DistanceRestraintGeometry(HostGuestRestraintGeometry):
    """
    Geometry definition for a distance restraint between a group of host
    atoms and a group of guest atoms.
    """

    def get_distance(self, universe: mda.Universe) -> unit.Quantity:
        """
        Measure the separation of the host and guest centres of mass.

        Parameters
        ----------
        universe : mda.Universe
            A Universe representing the system of interest.

        Returns
        -------
        bond : unit.Quantity
            The center of mass distance between the two groups of atoms,
            in angstrom.
        """
        host_com = universe.atoms[self.host_atoms].center_of_mass()
        guest_com = universe.atoms[self.guest_atoms].center_of_mass()
        separation = calc_bonds(
            host_com,
            guest_com,
            box=universe.atoms.dimensions,
        )
        # Cast to a builtin float so the Quantity does not wrap np.float64
        return float(separation) * unit.angstrom
def get_distance_restraint(
    topology: Union[str, pathlib.Path, app.Topology],
    trajectory: Union[str, pathlib.Path],
    host_atoms: Optional[list[int]] = None,
    guest_atoms: Optional[list[int]] = None,
    host_selection: Optional[str] = None,
    guest_selection: Optional[str] = None,
) -> DistanceRestraintGeometry:
    """
    Build a DistanceRestraintGeometry between a host and a guest group.

    Each group is chosen either by a list of atom indices or by an
    MDAnalysis selection string.

    Parameters
    ----------
    topology : Union[str, pathlib.Path, app.Topology]
        A path or object defining the system topology.
    trajectory : Union[str, pathlib.Path]
        Coordinates for the system.
    host_atoms : Optional[list[int]]
        A list of host atoms indices. Either ``host_atoms`` or
        ``host_selection`` must be defined.
    guest_atoms : Optional[list[int]]
        A list of guest atoms indices. Either ``guest_atoms`` or
        ``guest_selection`` must be defined.
    host_selection : Optional[str]
        An MDAnalysis selection string to define the host atoms.
        Either ``host_atoms`` or ``host_selection`` must be defined.
    guest_selection : Optional[str]
        An MDAnalysis selection string to define the guest atoms.
        Either ``guest_atoms`` or ``guest_selection`` must be defined.

    Returns
    -------
    DistanceRestraintGeometry
        An object that defines a distance restraint geometry.
    """
    universe = mda.Universe(
        topology,
        trajectory,
        topology_format=_get_mda_topology_format(topology)
    )

    # The guest group is resolved before the host group.
    guest_indices = [
        atom.ix
        for atom in _get_mda_selection(universe, guest_atoms, guest_selection)
    ]
    host_indices = [
        atom.ix
        for atom in _get_mda_selection(universe, host_atoms, host_selection)
    ]

    return DistanceRestraintGeometry(
        guest_atoms=guest_indices, host_atoms=host_indices
    )


def get_molecule_centers_restraint(
    molA_rdmol: Chem.Mol,
    molB_rdmol: Chem.Mol,
    molA_idxs: list[int],
    molB_idxs: list[int],
):
    """
    Build a DistanceRestraintGeometry between the central atoms of
    two molecules.

    The central atom of molecule A is stored as the single guest atom and
    the central atom of molecule B as the single host atom.

    Parameters
    ----------
    molA_rdmol : Chem.Mol
        An RDKit Molecule for the first molecule.
    molB_rdmol : Chem.Mol
        An RDKit Molecule for the second molecule.
    molA_idxs : list[int]
        The indices of the first molecule in the system. Note we assume these
        to be sorted in the same order as the input rdmol.
    molB_idxs : list[int]
        The indices of the second molecule in the system. Note we assume
        these to be sorted in the same order as the input rdmol.

    Returns
    -------
    DistanceRestraintGeometry
        An object that defines a distance restraint geometry.
    """
    # Map each molecule-local central atom index onto its system index;
    # this relies on the idxs lists following the rdmol atom order.
    local_center_a = get_central_atom_idx(molA_rdmol)
    guest_center = molA_idxs[local_center_a]
    local_center_b = get_central_atom_idx(molB_rdmol)
    host_center = molB_idxs[local_center_b]

    return DistanceRestraintGeometry(
        guest_atoms=[guest_center], host_atoms=[host_center]
    )
+# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +from typing import Union, Optional +import numpy as np +import numpy.typing as npt +from scipy.stats import circvar + +import openmm +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +from rdkit import Chem +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.rms import RMSF +from MDAnalysis.lib.distances import minimize_vectors, capped_distance +from MDAnalysis.coordinates.memory import MemoryReader +from MDAnalysis.transformations.nojump import NoJump + +from openfe_analysis.transformations import Aligner + + +DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 + + +def _get_mda_selection( + universe: mda.Universe, + atom_list: Optional[list[int]], + selection: Optional[str] +) -> mda.AtomGroup: + """ + Return an AtomGroup based on either a list of atom indices or an + mdanalysis string selection. + + Parameters + ---------- + universe : mda.Universe + The MDAnalysis Universe to get the AtomGroup from. + atom_list : Optional[list[int]] + A list of atom indices. + selection : Optional[str] + An MDAnalysis selection string. + + Returns + ------- + ag : mda.AtomGroup + An atom group selected from the inputs. + + Raises + ------ + ValueError + If both ``atom_list`` and ``selection`` are ``None`` + or are defined. 
+ """ + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + if selection is not None: + raise ValueError( + "both atom_list and selection cannot be defined together" + ) + ag = universe.atoms[atom_list] + return ag + + +def _get_mda_coord_format( + coordinates: Union[str, npt.NDArray] +) -> Optional[MemoryReader]: + """ + Helper to set the coordinate format to MemoryReader + if the coordinates are an NDArray. + + Parameters + ---------- + coordinates : Union[str, npt.NDArray] + + Returns + ------- + Optional[MemoryReader] + Either the MemoryReader class or None. + """ + if isinstance(coordinates, npt.NDArray): + return MemoryReader + else: + return None + + +def _get_mda_topology_format( + topology: Union[str, openmm.app.Topology] +) -> Optional[str]: + """ + Helper to set the topology format to OPENMMTOPOLOGY + if the topology is an openmm.app.Topology. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + + + Returns + ------- + Optional[str] + The string `OPENMMTOPOLOGY` or None. + """ + if isinstance(topology, openmm.app.Topology): + return "OPENMMTOPOLOGY" + else: + return None + + +def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: + """ + Get a list of tuples with the indices for each ring in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molecule + + Returns + ------- + list[tuple[int]] + List of tuples for each ring. 
def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]:
    """
    Get a list of tuples with the indices for each aromatic ring in an
    rdkit Molecule.

    Parameters
    ----------
    rdmol : Chem.Mol
        RDKit Molecule

    Returns
    -------
    list[tuple[int]]
        List of tuples of atom indices, one per ring in which every atom
        is aromatic.
    """
    ringinfo = rdmol.GetRingInfo()
    # A set gives constant-time membership tests for the per-atom check.
    arom_idxs = set(get_aromatic_atom_idxs(rdmol))

    return [
        ring for ring in ringinfo.AtomRings()
        if all(a in arom_idxs for a in ring)
    ]


def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]:
    """
    Helper method to get aromatic atoms idxs
    in a RDKit Molecule

    Parameters
    ----------
    rdmol : Chem.Mol
        RDKit Molecule

    Returns
    -------
    list[int]
        A list of the aromatic atom idxs
    """
    idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetIsAromatic()]
    return idxs


def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]:
    """
    Get idxs of heavy atoms in an RDKit Molecule

    Parameters
    ----------
    rdmol : Chem.Mol
        RDKit Molecule

    Returns
    -------
    list[int]
        A list of heavy atom idxs (atoms with an atomic number above 1)
    """
    idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetAtomicNum() > 1]
    return idxs


def get_central_atom_idx(rdmol: Chem.Mol) -> int:
    """
    Get the central atom in an rdkit Molecule.

    Parameters
    ----------
    rdmol : Chem.Mol
        RDKit Molecule to query

    Returns
    -------
    center : int
        Index of central atom in Molecule

    Raises
    ------
    ValueError
        If the molecule graph is not connected.

    Note
    ----
    If there are equal likelihood centers, will return
    the first entry.
    """
    # TODO: switch to a manual conversion to avoid an OpenFF dependency
    offmol = OFFMol(rdmol, allow_undefined_stereo=True)
    nx_mol = offmol.to_networkx()

    # The molecule graph is undirected, for which networkx does not
    # implement ``is_weakly_connected``; ``is_connected`` is the
    # undirected equivalent.
    if not nx.is_connected(nx_mol):
        errmsg = "A disconnected molecule was passed, cannot find the center"
        raise ValueError(errmsg)

    # We take the zero-th entry if there are multiple center
    # atoms (e.g. equal likelihood centers)
    center = nx.center(nx_mol)[0]
    return center
def is_collinear(positions, atoms, dimensions=None, threshold=0.9):
    """
    Test whether any pair of consecutive vectors along a sequence of atoms
    is collinear.

    Parameters
    ----------
    positions : openmm.unit.Quantity
        System positions.
    atoms : list[int]
        The indices of the atoms to test.
    dimensions : Optional[npt.NDArray]
        The dimensions of the system to minimize vectors.
    threshold : float
        Atoms are not collinear if their sequential vector separation dot
        products are less than ``threshold``. Default 0.9.

    Returns
    -------
    result : bool
        Returns True if any sequential pair of vectors is collinear;
        False otherwise.

    Notes
    -----
    Originally from Yank.
    """
    collinear_found = False
    # Walk over every run of three consecutive atoms (first, middle, last).
    for first, middle, last in zip(atoms, atoms[1:], atoms[2:]):
        leading = minimize_vectors(
            positions[middle, :] - positions[first, :],
            box=dimensions,
        )
        trailing = minimize_vectors(
            positions[last, :] - positions[middle, :],
            box=dimensions,
        )
        # Cosine of the angle between the two minimum-image vectors.
        norm_product = np.sqrt(
            np.dot(leading, leading) * np.dot(trailing, trailing)
        )
        cosine = np.dot(leading, trailing) / norm_product
        if not collinear_found:
            collinear_found = np.abs(cosine) > threshold
    return collinear_found
def check_angle_not_flat(
    angle: unit.Quantity,
    force_constant: unit.Quantity = DEFAULT_ANGLE_FRC_CONSTANT,
    temperature: unit.Quantity = 298.15 * unit.kelvin,
) -> bool:
    """
    Check whether the chosen angle is less than 10 kT from 0 or pi radians

    Parameters
    ----------
    angle : unit.Quantity
        The angle to check in units compatible with radians.
    force_constant : unit.Quantity
        Force constant of the angle in units compatible with
        kilojoule_per_mole / radians ** 2.
    temperature : unit.Quantity
        The system temperature in units compatible with Kelvin.
        Defaults to 298.15 Kelvin.

    Returns
    -------
    bool
        False if the angle is less than 10 kT from 0 or pi radians
    """
    # Work with bare magnitudes in fixed units so the final comparison is
    # between plain numbers rather than a dimensional Quantity and a float.
    angle_rads = angle.to("radians").m
    frc_const = force_constant.to("kilojoule_per_mole / radians ** 2").m
    temp_kelvin = temperature.to("kelvin").m
    RT = 8.31445985 * 0.001 * temp_kelvin  # kJ / mol

    # Harmonic energy needed to reach 0 or pi radians, in multiples of RT
    ang_check_1 = 0.5 * frc_const * (angle_rads - 0.0) ** 2 / RT
    ang_check_2 = 0.5 * frc_const * (angle_rads - np.pi) ** 2 / RT

    return not (ang_check_1 < 10.0 or ang_check_2 < 10.0)


def check_dihedral_bounds(
    dihedral: unit.Quantity,
    lower_cutoff: unit.Quantity = -2.618 * unit.radians,
    upper_cutoff: unit.Quantity = 2.618 * unit.radians,
) -> bool:
    """
    Check that a dihedral does not exceed the bounds set by
    lower_cutoff and upper_cutoff.

    Parameters
    ----------
    dihedral : unit.Quantity
        Dihedral in units compatible with radians.
    lower_cutoff : unit.Quantity
        Dihedral lower cutoff in units compatible with radians.
        Defaults to -2.618 radians.
    upper_cutoff : unit.Quantity
        Dihedral upper cutoff in units compatible with radians.
        Defaults to 2.618 radians.

    Returns
    -------
    bool
        ``True`` if the dihedral is within the upper and lower
        cutoff bounds.
    """
    # The lower cutoff must be the smaller value; with the defaults the
    # other way round no dihedral could ever be inside the bounds.
    if (dihedral < lower_cutoff) or (dihedral > upper_cutoff):
        return False
    return True
def check_angular_variance(
    angles: unit.Quantity,
    upper_bound: unit.Quantity,
    lower_bound: unit.Quantity,
    width: unit.Quantity,
) -> bool:
    """
    Check that the variance of a list of ``angles`` does not exceed
    a given ``width``

    Parameters
    ----------
    angles : ArrayLike[unit.Quantity]
        An array of angles in units compatible with radians.
    upper_bound: unit.Quantity
        The upper bound in the angle range in radians compatible units.
    lower_bound: unit.Quantity
        The lower bound in the angle range in radians compatible units.
    width : unit.Quantity
        The width to check the variance against, in units compatible with
        radians.

    Returns
    -------
    bool
        ``True`` if the variance of the angles does not exceed the width.

    """
    variance = circvar(
        angles.to("radians").m,
        high=upper_bound.to("radians").m,
        low=lower_bound.to("radians").m,
    )
    return not (variance * unit.radians > width)


class FindHostAtoms(AnalysisBase):
    """
    Class filter host atoms based on their distance
    from a set of guest atoms.

    Parameters
    ----------
    host_atoms : MDAnalysis.AtomGroup
        Initial selection of host atoms to filter from.
    guest_atoms : MDAnalysis.AtomGroup
        Selection of guest atoms to search around.
    min_search_distance: unit.Quantity
        Minimum distance to filter atoms within.
    max_search_distance: unit.Quantity
        Maximum distance to filter atoms within.

    Notes
    -----
    After ``run()``, ``results.host_idxs`` is a sorted integer array of the
    universe indices of every host atom that was within the distance
    window of a guest atom in at least one frame.
    """
    _analysis_algorithm_is_parallelizable = False

    def __init__(
        self,
        host_atoms,
        guest_atoms,
        min_search_distance,
        max_search_distance,
        **kwargs,
    ):
        super().__init__(host_atoms.universe.trajectory, **kwargs)

        self.host_ag = host_atoms
        self.guest_ag = guest_atoms
        # Bare magnitudes in angstrom, as used for the distance search.
        self.min_cutoff = min_search_distance.to("angstrom").m
        self.max_cutoff = max_search_distance.to("angstrom").m

    def _prepare(self):
        self.results.host_idxs = set()

    def _single_frame(self):
        pairs = capped_distance(
            reference=self.host_ag.positions,
            configuration=self.guest_ag.positions,
            max_cutoff=self.max_cutoff,
            min_cutoff=self.min_cutoff,
            box=self.guest_ag.universe.dimensions,
            return_distances=False,
        )

        # Column 0 of ``pairs`` indexes ``reference`` (the host group) and
        # column 1 indexes ``configuration`` (the guest group), so the host
        # atoms are looked up in ``host_ag`` through column 0.
        host_idxs = [int(self.host_ag.atoms[p].ix) for p in pairs[:, 0]]
        self.results.host_idxs.update(host_idxs)

    def _conclude(self):
        # ``np.array`` on a set would build a 0-d object array; convert to
        # a sorted sequence first to get a 1-d integer array.
        self.results.host_idxs = np.array(
            sorted(self.results.host_idxs), dtype=int
        )
def get_local_rmsf(atomgroup: mda.AtomGroup) -> unit.Quantity:
    """
    Compute the RMSF of an AtomGroup after aligning it upon itself.

    Parameters
    ----------
    atomgroup : MDAnalysis.AtomGroup
        The atoms to compute the RMSF for.

    Return
    ------
    rmsf
        ArrayQuantity of RMSF values.
    """
    # Work on a copy of the Universe so the transformations added below
    # are not attached to the caller's trajectory.
    universe_copy = atomgroup.universe.copy()
    copied_atoms = universe_copy.atoms[atomgroup.atoms.ix]

    # Transformations are created in the same order they are applied:
    # NoJump first, then the alignment on the copied atoms.
    transformations = (NoJump(), Aligner(copied_atoms))
    universe_copy.trajectory.add_transformations(*transformations)

    rmsf_analysis = RMSF(copied_atoms)
    rmsf_analysis.run()
    return rmsf_analysis.results.rmsf * unit.angstrom
+ """ + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + return energy_function + + +def get_periodic_boresch_energy_function( + control_parameter: str, +) -> str: + """ + Return a Boresch-style energy function with a periodic torsion for a + CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + return energy_function + + +def get_custom_compound_bond_force( + energy_function: str, n_particles: int = 6, +): + """ + Return an OpenMM CustomCompoundForce + + TODO + ---- + Change this to a direct subclass like openmmtools.force. + + Acknowledgements + ---------------- + Boresch-like energy functions are reproduced from `Yank `_ + """ + return openmm.CustomCompoundBondForce(n_particles, energy_function) + + +def add_force_in_separate_group( + system: openmm.System, + force: openmm.Force, +): + """ + Add force to a System in a separate force group. 
+ + Parameters + ---------- + system : openmm.System + System to add the Force to. + force : openmm.Force + The Force to add to the System. + + Raises + ------ + ValueError + If all 32 force groups are occupied. + + + TODO + ---- + Unlike the original Yank implementation, we assume that + all 32 force groups will not be filled. Should this be an issue + we can consider just separating it from NonbondedForce. + + Acknowledgements + ---------------- + Mostly reproduced from `Yank `_. + """ + available_force_groups = set(range(32)) + for existing_force in system.getForces(): + available_force_groups.discard(existing_force.getForceGroup()) + + if len(available_force_groups) == 0: + errmsg = "No available force groups could be found" + raise ValueError(errmsg) + + force.setForceGroup(min(available_force_groups)) + system.addForce(force) diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py new file mode 100644 index 000000000..c77b1cd0b --- /dev/null +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -0,0 +1,697 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Classes for applying restraints to OpenMM Systems. + +Acknowledgements +---------------- +Many of the classes here are at least in part inspired from +`Yank `_ and +`OpenMMTools `_. + +TODO +---- +* Add relevant duecredit entries. 
class RestraintParameterState(GlobalParameterState):
    """
    Composable state to control ``lambda_restraints`` OpenMM Force parameters.

    See :class:`openmmtools.states.GlobalParameterState` for more details.

    Parameters
    ----------
    parameters_name_suffix : Optional[str]
      If specified, the state will control a modified version of the parameter
      ``lambda_restraints_{parameters_name_suffix}`` instead of just
      ``lambda_restraints``.
    lambda_restraints : Optional[float]
      The strength of the restraint. If defined, must be between 0 and 1.

    Raises
    ------
    ValueError
      If ``lambda_restraints`` is set to a value that is not ``None`` and
      lies outside of the closed interval [0.0, 1.0].

    Acknowledgement
    ---------------
    Partially reproduced from Yank.
    """

    # Descriptor exposing the global parameter on the state.
    lambda_restraints = GlobalParameterState.GlobalParameter(
        "lambda_restraints", standard_value=1.0
    )

    # The validator deliberately re-uses the descriptor's name, so the
    # class attribute is rebound to whatever the decorator returns.
    @lambda_restraints.validator
    def lambda_restraints(self, instance, new_value):
        # None is passed through untouched; only real values are range-checked.
        if new_value is not None and not (0.0 <= new_value <= 1.0):
            errmsg = (
                "lambda_restraints must be between 0.0 and 1.0 "
                f"and got {new_value}"
            )
            raise ValueError(errmsg)
        # Not crashing out on None to match upstream behaviour
        return new_value
class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints):
    """
    A base class for all radially symmetric Forces acting between
    two sets of atoms.

    Must be subclassed: child classes implement ``_get_force`` to build the
    concrete OpenMM Force.
    """
    # Control parameter name used when a Force is only built in order to
    # evaluate its standard state correction.
    _default_controlling_parameter_name = "lambda_restraints"

    def _verify_inputs(self) -> None:
        """
        Check that the settings held by this object are distance
        restraint settings.

        Raises
        ------
        ValueError
          If ``self.settings`` is not a ``DistanceRestraintSettings``.
        """
        # NOTE(review): ``DistanceRestraintSettings`` is not among the names
        # imported at the top of this module - confirm it is in scope.
        if not isinstance(self.settings, DistanceRestraintSettings):
            errmsg = f"Incorrect settings type {self.settings} passed through"
            raise ValueError(errmsg)

    def _verify_geometry(self, geometry: DistanceRestraintGeometry):
        """
        Check that the geometry is a distance restraint geometry.

        Raises
        ------
        ValueError
          If ``geometry`` is not a ``DistanceRestraintGeometry``.
        """
        if not isinstance(geometry, DistanceRestraintGeometry):
            errmsg = f"Incorrect geometry class type {geometry} passed through"
            raise ValueError(errmsg)

    def add_force(
        self,
        thermodynamic_state: ThermodynamicState,
        geometry: DistanceRestraintGeometry,
        controlling_parameter_name: str = "lambda_restraints",
    ) -> None:
        """
        Method for in-place adding the Force to the System of the
        given ThermodynamicState.

        Parameters
        ----------
        thermodynamic_state : ThermodynamicState
          The ThermodynamicState with a System to inplace modify with the
          new force.
        geometry : DistanceRestraintGeometry
          A geometry object defining the restraint parameters.
        controlling_parameter_name : str
          The name of the controlling parameter for the Force.

        Raises
        ------
        ValueError
          If the geometry is of the wrong type.
        """
        self._verify_geometry(geometry)
        force = self._get_force(geometry, controlling_parameter_name)
        force.setUsesPeriodicBoundaryConditions(
            thermodynamic_state.is_periodic
        )
        # Note .system is a call to get_system() so it's returning a copy;
        # the modified copy has to be assigned back to take effect.
        system = thermodynamic_state.system
        add_force_in_separate_group(system, force)
        thermodynamic_state.system = system

    def get_standard_state_correction(
        self,
        thermodynamic_state: ThermodynamicState,
        geometry: DistanceRestraintGeometry,
    ) -> unit.Quantity:
        """
        Get the standard state correction for the Force when
        applied to the input ThermodynamicState.

        Parameters
        ----------
        thermodynamic_state : ThermodynamicState
          The ThermodynamicState the correction is evaluated for.
        geometry : DistanceRestraintGeometry
          A geometry object defining the restraint parameters.

        Returns
        -------
        correction : unit.Quantity
          The standard state correction free energy in units compatible
          with kilojoule per mole.

        Raises
        ------
        ValueError
          If the geometry is of the wrong type.
        """
        self._verify_geometry(geometry)
        # ``_get_force`` requires a control parameter name. The Force built
        # here is only used to evaluate the correction and is never added
        # to a System, so the default name is passed.
        force = self._get_force(
            geometry, self._default_controlling_parameter_name
        )
        corr = force.compute_standard_state_correction(
            thermodynamic_state, volume="system"
        )
        # The correction is scaled by kT to obtain an energy.
        dg = corr * thermodynamic_state.kT
        return from_openmm(dg).to('kilojoule_per_mole')

    def _get_force(
        self,
        geometry: DistanceRestraintGeometry,
        controlling_parameter_name: str
    ):
        """
        Build the OpenMM Force for the given geometry.

        Raises
        ------
        NotImplementedError
          Always; child classes must override this method.
        """
        raise NotImplementedError("only implemented in child classes")
class FlatBottomBondRestraint(
    SingleBondMixin, BaseRadiallySymmetricRestraintForce
):
    """
    A class to add a flat bottom restraint between two atoms
    in an OpenMM system.

    The restraint is defined as a
    :class:`openmmtools.forces.FlatBottomRestraintBondForce`.

    Notes
    -----
    * Settings must contain a ``spring_constant`` for the Force. It is
      converted to the OpenMM MD unit system before use.
    * The geometry must provide a ``well_radius``, which is also converted
      to the OpenMM MD unit system.
    * ``SingleBondMixin`` is listed first on purpose. Its
      ``_verify_geometry`` must come before the base class implementation
      in the method resolution order, otherwise the single-atom check is
      never executed.
    """
    def _get_force(
        self,
        geometry: DistanceRestraintGeometry,
        controlling_parameter_name: str,
    ) -> openmm.Force:
        """
        Get the FlatBottomRestraintBondForce given an input geometry.

        Parameters
        ----------
        geometry : DistanceRestraintGeometry
          A geometry class that defines how the Force is applied.
        controlling_parameter_name : str
          The name of the controlling parameter for the Force.

        Returns
        -------
        FlatBottomRestraintBondForce
          An OpenMM Force that applies a flat bottom restraint between
          two atoms.
        """
        spring_constant = to_openmm(
            self.settings.spring_constant
        ).value_in_unit_system(omm_unit.md_unit_system)
        well_radius = to_openmm(
            geometry.well_radius
        ).value_in_unit_system(omm_unit.md_unit_system)
        # Only the first host and guest atoms are used.
        return FlatBottomRestraintBondForce(
            spring_constant=spring_constant,
            well_radius=well_radius,
            restrained_atom_index1=geometry.host_atoms[0],
            restrained_atom_index2=geometry.guest_atoms[0],
            controlling_parameter_name=controlling_parameter_name,
        )
+ + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintForce + An OpenMM Force that applies a harmonic restraint between + the centroid of two sets of atoms. + """ + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + return HarmonicRestraintForce( + spring_constant=spring_constant, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, + controlling_parameter_name=controlling_parameter_name, + ) + + +class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): + """ + A class to add a flat bottom restraint between the centroid + of two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. 
class BoreschRestraint(BaseHostGuestRestraints):
    """
    A class to add a Boresch-like restraint between six atoms.

    The restraint is defined as a
    :class:`openmm.CustomCompoundBondForce` with the
    following energy function:

        lambda_control_parameter * E;
        E = (K_r/2)*(distance(p3,p4) - r_aA0)^2
            + (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2
            + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2
            + (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2
            + (K_phiC/2)*dphi_C^2;
        dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi);
        dA = dihedral(p1,p2,p3,p4) - phi_A0;
        dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi);
        dB = dihedral(p2,p3,p4,p5) - phi_B0;
        dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi);
        dC = dihedral(p3,p4,p5,p6) - phi_C0;

    Where p1, p2, p3, p4, p5, p6 represent host atoms 2, 1, 0,
    and guest atoms 0, 1, 2 respectively.

    ``lambda_control_parameter`` is a control parameter for
    scaling the Force.

    ``K_r`` is defined as the bond spring constant between
    p3 and p4 and must be provided in the settings in units
    compatible with kilojoule / mole / nanometer ** 2.

    ``r_aA0`` is the equilibrium distance of the bond between
    p3 and p4. This must be provided by the Geometry class in
    units compatible with nanometer.

    ``K_thetaA`` and ``K_thetaB`` are the spring constants for the angles
    formed by (p2, p3, p4) and (p3, p4, p5). They must be provided in the
    settings in units compatible with kilojoule / mole / radians ** 2.

    ``theta_A0`` and ``theta_B0`` are the equilibrium values for angles
    (p2, p3, p4) and (p3, p4, p5). They must be provided by the
    Geometry class in units compatible with radians.

    ``K_phiA``, ``K_phiB``, and ``K_phiC`` are the spring constants
    for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and
    (p3, p4, p5, p6). They must be provided in the settings in units
    compatible with kilojoule / mole / radians ** 2.

    ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium values
    for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and
    (p3, p4, p5, p6). They must be provided in the Geometry class in
    units compatible with radians.


    Notes
    -----
    * Settings must define ``K_r``, ``K_thetaA``, ``K_thetaB``,
      ``K_phiA``, ``K_phiB`` and ``K_phiC``.
    * The geometry must define ``r_aA0``, ``theta_A0``, ``theta_B0``,
      ``phi_A0``, ``phi_B0`` and ``phi_C0``, as well as three host and
      three guest atom indices.
    """
    def _verify_inputs(self) -> None:
        """
        Method for validating that the settings object is correct.

        Raises
        ------
        ValueError
          If ``self.settings`` is not a ``BoreschRestraintSettings``.
        """
        # NOTE(review): ``BoreschRestraintSettings`` is not among the names
        # imported at the top of this module - confirm it is in scope.
        if not isinstance(self.settings, BoreschRestraintSettings):
            errmsg = f"Incorrect settings type {self.settings} passed through"
            raise ValueError(errmsg)

    def _verify_geometry(self, geometry: BoreschRestraintGeometry):
        """
        Method for validating that the geometry object is correct.

        Raises
        ------
        ValueError
          If ``geometry`` is not a ``BoreschRestraintGeometry``.
        """
        if not isinstance(geometry, BoreschRestraintGeometry):
            errmsg = f"Incorrect geometry class type {geometry} passed through"
            raise ValueError(errmsg)

    def add_force(
        self,
        thermodynamic_state: ThermodynamicState,
        geometry: BoreschRestraintGeometry,
        controlling_parameter_name: str = "lambda_restraints",
    ) -> None:
        """
        Method for in-place adding the Boresch CustomCompoundBondForce
        to the System of the given ThermodynamicState.

        Parameters
        ----------
        thermodynamic_state : ThermodynamicState
          The ThermodynamicState with a System to inplace modify with the
          new force.
        geometry : BoreschRestraintGeometry
          A geometry object defining the restraint parameters.
        controlling_parameter_name : str
          The name of the controlling parameter for the Force.

        Raises
        ------
        ValueError
          If the geometry is of the wrong type.
        """
        self._verify_geometry(geometry)
        force = self._get_force(
            geometry,
            controlling_parameter_name,
        )
        force.setUsesPeriodicBoundaryConditions(
            thermodynamic_state.is_periodic
        )
        # Note .system is a call to get_system() so it's returning a copy;
        # the modified copy has to be assigned back to take effect.
        system = thermodynamic_state.system
        add_force_in_separate_group(system, force)
        thermodynamic_state.system = system

    def _get_force(
        self,
        geometry: BoreschRestraintGeometry,
        controlling_parameter_name: str
    ) -> openmm.CustomCompoundBondForce:
        """
        Get the CustomCompoundBondForce with a Boresch-like energy function
        given an input geometry.

        Parameters
        ----------
        geometry : BoreschRestraintGeometry
          A geometry class that defines how the Force is applied.
        controlling_parameter_name : str
          The name of the controlling parameter for the Force.

        Returns
        -------
        CustomCompoundBondForce
          An OpenMM CustomCompoundBondForce that applies a Boresch-like
          restraint between 6 atoms.
        """
        efunc = get_boresch_energy_function(controlling_parameter_name)

        force = get_custom_compound_bond_force(
            energy_function=efunc, n_particles=6,
        )

        # Force constants come from the settings, equilibrium values from
        # the geometry. Insertion order matters: it fixes the order of the
        # per-bond parameters and of the values passed to addBond.
        parameter_dict = {
            'K_r': self.settings.K_r,
            'r_aA0': geometry.r_aA0,
            'K_thetaA': self.settings.K_thetaA,
            'theta_A0': geometry.theta_A0,
            'K_thetaB': self.settings.K_thetaB,
            'theta_B0': geometry.theta_B0,
            'K_phiA': self.settings.K_phiA,
            'phi_A0': geometry.phi_A0,
            'K_phiB': self.settings.K_phiB,
            'phi_B0': geometry.phi_B0,
            'K_phiC': self.settings.K_phiC,
            'phi_C0': geometry.phi_C0,
        }

        param_values = []
        for key, val in parameter_dict.items():
            param_values.append(
                to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)
            )
            force.addPerBondParameter(key)

        force.addGlobalParameter(controlling_parameter_name, 1.0)

        # p1..p6: host atoms in reverse order followed by the guest atoms.
        atoms = [
            geometry.host_atoms[2],
            geometry.host_atoms[1],
            geometry.host_atoms[0],
            geometry.guest_atoms[0],
            geometry.guest_atoms[1],
            geometry.guest_atoms[2],
        ]
        force.addBond(atoms, param_values)
        return force

    def get_standard_state_correction(
        self,
        thermodynamic_state: ThermodynamicState,
        geometry: BoreschRestraintGeometry
    ) -> unit.Quantity:
        """
        Get the standard state correction for the Boresch-like
        restraint when applied to the input ThermodynamicState.

        The correction is calculated using the analytical method
        as defined by Boresch et al. [1]

        Parameters
        ----------
        thermodynamic_state : ThermodynamicState
          The ThermodynamicState the correction is evaluated for; only its
          ``kT`` is used.
        geometry : BoreschRestraintGeometry
          A geometry object defining the restraint parameters.

        Returns
        -------
        correction : unit.Quantity
          The standard state correction free energy in units of
          kilojoule per mole.

        Raises
        ------
        ValueError
          If the geometry is of the wrong type.

        References
        ----------
        [1] Boresch S, Tettinger F, Leitgeb M, Karplus M. J Phys Chem B. 107:9535, 2003.
            http://dx.doi.org/10.1021/jp0217839
        """
        self._verify_geometry(geometry)

        # Standard state volume used by the analytical expression.
        StandardV = 1.66053928 * unit.nanometer**3
        kt = from_openmm(thermodynamic_state.kT)

        # distances
        r_aA0 = geometry.r_aA0.to('nm')
        sin_thetaA0 = np.sin(geometry.theta_A0.to('radians'))
        sin_thetaB0 = np.sin(geometry.theta_B0.to('radians'))

        # restraint energies
        K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2')
        K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2')
        K_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2')
        K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2')
        K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2')
        K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2')

        numerator1 = 8.0 * (np.pi**2) * StandardV
        denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0
        numerator2 = np.sqrt(
            K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC
        )
        denum2 = (2.0 * np.pi * kt)**3

        dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2))

        # Report in kilojoule per mole, as documented and as done by the
        # radially symmetric restraints, rather than in whatever unit kT
        # happens to carry.
        return dG.to('kilojoule_per_mole')
+ """ + class Config: + arbitrary_types_allowed = True diff --git a/openfe/tests/protocols/restraints/__init__.py b/openfe/tests/protocols/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/tests/protocols/restraints/test_geometry_base.py b/openfe/tests/protocols/restraints/test_geometry_base.py new file mode 100644 index 000000000..139c57dc5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_base.py @@ -0,0 +1,25 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.geometry.base import ( + HostGuestRestraintGeometry +) + + +def test_hostguest_geometry(): + """ + A very basic will it build test. + """ + geom = HostGuestRestraintGeometry(guest_atoms=[1, 2, 3], host_atoms=[4]) + + assert isinstance(geom, HostGuestRestraintGeometry) + + +def test_hostguest_positiveidxs_validator(): + """ + Check that the validator is working as intended. + """ + with pytest.raises(ValueError, match="negative indices passed"): + geom = HostGuestRestraintGeometry(guest_atoms=[-1, 1], host_atoms=[0]) diff --git a/openfe/tests/protocols/restraints/test_omm_restraints.py b/openfe/tests/protocols/restraints/test_omm_restraints.py new file mode 100644 index 000000000..0e346f9c5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_omm_restraints.py @@ -0,0 +1,31 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.openmm.omm_restraints import ( + RestraintParameterState, +) + + +def test_parameter_state_default(): + param_state = RestraintParameterState() + assert param_state.lambda_restraints is None + + +@pytest.mark.parametrize('suffix', [None, 'foo']) +@pytest.mark.parametrize('lambda_var', [0, 0.5, 1.0]) +def test_parameter_state_suffix(suffix, lambda_var): + param_state = RestraintParameterState( + parameters_name_suffix=suffix, lambda_restraints=lambda_var + ) + + if suffix is not None: + param_name = f'lambda_restraints_{suffix}' + else: + param_name = 'lambda_restraints' + + assert getattr(param_state, param_name) == lambda_var + assert len(param_state._parameters.keys()) == 1 + assert param_state._parameters[param_name] == lambda_var + assert param_state._parameters_name_suffix == suffix diff --git a/openfe/tests/protocols/restraints/test_openmm_forces.py b/openfe/tests/protocols/restraints/test_openmm_forces.py new file mode 100644 index 000000000..cd2a7f21e --- /dev/null +++ b/openfe/tests/protocols/restraints/test_openmm_forces.py @@ -0,0 +1,115 @@ +# This code is part of OpenFE and is licensed under the MIT license. 
+# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest +import numpy as np +import openmm +from openfe.protocols.restraint_utils.openmm.omm_forces import ( + get_boresch_energy_function, + get_periodic_boresch_energy_function, + get_custom_compound_bond_force, + add_force_in_separate_group, +) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_periodic_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_periodic_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('num_atoms', [6, 20]) +def test_custom_compound_force(num_atoms): + fn = get_boresch_energy_function('lambda_restraints') + force = get_custom_compound_bond_force(fn, num_atoms) + + # Check we have the right object + assert isinstance(force, 
@pytest.mark.parametrize('groups, expected', [
    [[0, 1, 2, 3, 4], 5],
    [[1, 2, 3, 4, 5], 0],
])
def test_add_force_in_separate_group(groups, expected):
    """
    Check that a new Force is placed in the lowest unoccupied force group.
    """
    # Create an empty system
    system = openmm.System()

    # Create some forces with some force groups
    base_forces = [
        openmm.NonbondedForce(),
        openmm.HarmonicBondForce(),
        openmm.HarmonicAngleForce(),
        openmm.PeriodicTorsionForce(),
        openmm.CMMotionRemover(),
    ]

    for force, group in zip(base_forces, groups):
        force.setForceGroup(group)

    for force in base_forces:
        system.addForce(force)

    # Get your CustomCompoundBondForce
    fn = get_boresch_energy_function('lambda_restraints')
    new_force = get_custom_compound_bond_force(fn, 6)
    add_force_in_separate_group(system=system, force=new_force)

    # Collect the groups of every CustomCompoundBondForce in the System.
    # Comparing against a one-element list also fails if the Force was
    # never added, which a bare loop of asserts would silently accept.
    custom_force_groups = [
        force.getForceGroup()
        for force in system.getForces()
        if isinstance(force, openmm.CustomCompoundBondForce)
    ]
    assert custom_force_groups == [expected]