Compute selection: deviceIndex & enforce 1 thread in vacuum (#752)
* Fix openmm compute platform selection for issues 739 and 704

* Add rever entry

* fix typing

* Remove erroneous extra file

* fix gufe keys

* Update news/compute_selection_fixes.rst

Co-authored-by: Mike Henry <[email protected]>

* Update omm_compute.py

* import os

* NETCDF3_64BIT is now NETCDF3_64BIT_OFFSET

---------

Co-authored-by: Mike Henry <[email protected]>
IAlibay and mikemhenry authored Dec 5, 2024
1 parent be3433c commit 72d623a
Showing 11 changed files with 133 additions and 31 deletions.
26 changes: 26 additions & 0 deletions news/compute_selection_fixes.rst
@@ -0,0 +1,26 @@
**Added:**

* OpenMMEngineSettings now has a `gpu_device_index` attribute
allowing users to pass through a list of ints to select the
GPU devices to run their simulations on.

**Changed:**

* `openfe.protocols.openmm_rfe._rfe_utils.compute` has been moved
to `openfe.protocols.openmm_utils.omm_compute`.

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* OpenMM CPU vacuum calculations now enforce the use of a single CPU to avoid large performance losses.

**Security:**

* <news item>
28 changes: 21 additions & 7 deletions openfe/protocols/openmm_afe/base.py
@@ -58,12 +58,13 @@
MultiStateSimulationSettings, OpenMMEngineSettings,
IntegratorSettings, LambdaSettings, MultiStateOutputSettings,
ThermoSettings, OpenFFPartialChargeSettings,
OpenMMSystemGeneratorFFSettings,
)
from openfe.protocols.openmm_rfe._rfe_utils import compute
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
from ..openmm_utils import (
settings_validation, system_creation,
multistate_analysis, charge_generation
multistate_analysis, charge_generation,
omm_compute,
)
from openfe.utils import without_oechem_backend

@@ -175,6 +176,7 @@ def _pre_equilibrate(
settings : dict[str, SettingsBaseModel]
A dictionary of settings objects. Expects the
following entries:
* `forcefield_settings`
* `engine_settings`
* `thermo_settings`
* `integrator_settings`
@@ -189,8 +191,12 @@
Equilibrated system positions
"""
# Prep the simulation object
platform = compute.get_openmm_platform(
settings['engine_settings'].compute_platform
# Restrict CPU count if running vacuum simulation
restrict_cpu = settings['forcefield_settings'].nonbonded_method.lower() == 'nocutoff'
platform = omm_compute.get_openmm_platform(
platform_name=settings['engine_settings'].compute_platform,
gpu_device_index=settings['engine_settings'].gpu_device_index,
restrict_cpu_count=restrict_cpu
)

integrator = openmm.LangevinMiddleIntegrator(
@@ -710,14 +716,16 @@ def _get_reporter(

def _get_ctx_caches(
self,
forcefield_settings: OpenMMSystemGeneratorFFSettings,
engine_settings: OpenMMEngineSettings
) -> tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]:
"""
Set the context caches based on the chosen platform
Parameters
----------
engine_settings : OpenMMEngineSettings,
forcefield_settings: OpenMMSystemGeneratorFFSettings
engine_settings : OpenMMEngineSettings
Returns
-------
@@ -726,8 +734,13 @@ def _get_ctx_caches(
sampler_context_cache : openmmtools.cache.ContextCache
The sampler state context cache.
"""
platform = compute.get_openmm_platform(
engine_settings.compute_platform,
# Get the compute platform
# Set the number of CPUs to 1 if running a vacuum simulation
restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff'
platform = omm_compute.get_openmm_platform(
platform_name=engine_settings.compute_platform,
gpu_device_index=engine_settings.gpu_device_index,
restrict_cpu_count=restrict_cpu
)

energy_context_cache = openmmtools.cache.ContextCache(
@@ -1026,6 +1039,7 @@ def run(self, dry=False, verbose=True,
try:
# 12. Get context caches
energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches(
settings['forcefield_settings'],
settings['engine_settings']
)

10 changes: 6 additions & 4 deletions openfe/protocols/openmm_md/plain_md_methods.py
@@ -43,10 +43,9 @@
)
from openff.toolkit.topology import Molecule as OFFMolecule

from openfe.protocols.openmm_rfe._rfe_utils import compute
from openfe.protocols.openmm_utils import (
system_validation, settings_validation, system_creation,
charge_generation,
charge_generation, omm_compute
)

logger = logging.getLogger(__name__)
@@ -623,8 +622,11 @@ def run(self, *, dry=False, verbose=True,
)

# 10. Get platform
platform = compute.get_openmm_platform(
protocol_settings.engine_settings.compute_platform
restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff'
platform = omm_compute.get_openmm_platform(
platform_name=protocol_settings.engine_settings.compute_platform,
gpu_device_index=protocol_settings.engine_settings.gpu_device_index,
restrict_cpu_count=restrict_cpu
)

# 11. Set the integrator
1 change: 0 additions & 1 deletion openfe/protocols/openmm_rfe/_rfe_utils/__init__.py
@@ -1,5 +1,4 @@
from . import (
compute,
lambdaprotocol,
multistate,
relative,
12 changes: 8 additions & 4 deletions openfe/protocols/openmm_rfe/equil_rfe_methods.py
@@ -61,7 +61,7 @@
)
from ..openmm_utils import (
system_validation, settings_validation, system_creation,
multistate_analysis, charge_generation
multistate_analysis, charge_generation, omm_compute,
)
from . import _rfe_utils
from ...utils import without_oechem_backend, log_system_probe
@@ -933,9 +933,13 @@ def run(self, *, dry=False, verbose=True,
bfactors=bfactors,
)

# 10. Get platform
platform = _rfe_utils.compute.get_openmm_platform(
protocol_settings.engine_settings.compute_platform
# 10. Get compute platform
# restrict to a single CPU if running vacuum
restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff'
platform = omm_compute.get_openmm_platform(
platform_name=protocol_settings.engine_settings.compute_platform,
gpu_device_index=protocol_settings.engine_settings.gpu_device_index,
restrict_cpu_count=restrict_cpu
)

# 11. Set the integrator
openfe/protocols/openmm_utils/omm_compute.py
@@ -1,23 +1,40 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
# Adapted Perses' perses.app.setup_relative_calculation.get_openmm_platform
from typing import Optional
import warnings
import logging
import os


logger = logging.getLogger(__name__)


def get_openmm_platform(platform_name=None):
def get_openmm_platform(
platform_name: Optional[str] = None,
gpu_device_index: Optional[list[int]] = None,
restrict_cpu_count: bool = False
):
"""
Return OpenMM's platform object based on given name. Setting to mixed
precision if using CUDA or OpenCL.
Parameters
----------
platform_name : str, optional, default=None
platform_name : Optional[str]
String with the platform name. If None, it will use the fastest
platform supporting mixed precision.
Default ``None``.
gpu_device_index : Optional[list[int]]
GPU device index selection. If ``None``, the default OpenMM
GPU selection will be used.
See the `OpenMM platform properties documentation <http://docs.openmm.org/latest/userguide/library/04_platform_specifics.html>`_
for more details.
Default ``None``.
restrict_cpu_count : bool
Optional hint to restrict the CPU count to 1 when
``platform_name`` is CPU. This allows Protocols to avoid
large performance losses in cases such as vacuum simulations.
Default ``False``.
Returns
-------
@@ -44,16 +61,23 @@ def get_openmm_platform(platform_name=None):
# Set precision and properties
name = platform.getName()
if name in ['CUDA', 'OpenCL']:
platform.setPropertyDefaultValue(
'Precision', 'mixed')
platform.setPropertyDefaultValue('Precision', 'mixed')
if gpu_device_index is not None:
index_list = ','.join(str(i) for i in gpu_device_index)
platform.setPropertyDefaultValue('DeviceIndex', index_list)

if name == 'CUDA':
platform.setPropertyDefaultValue(
'DeterministicForces', 'true')

if name != 'CUDA':
wmsg = (f"Non-GPU platform selected: {name}, this may significantly "
wmsg = (f"Non-CUDA platform selected: {name}, this may significantly "
"impact simulation performance")
warnings.warn(wmsg)
logging.warning(wmsg)

if name == 'CPU' and restrict_cpu_count:
threads = os.getenv("OPENMM_CPU_THREADS", '1')
platform.setPropertyDefaultValue('Threads', threads)

return platform
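A short usage sketch of the new keyword arguments (values are illustrative only, not part of the commit):

# Illustrative calls mirroring the signature added above.
from openfe.protocols.openmm_utils import omm_compute

# GPU run pinned to the first device; the helper sets mixed precision.
gpu_platform = omm_compute.get_openmm_platform(
    platform_name="CUDA",
    gpu_device_index=[0],
)

# Vacuum-style CPU run restricted to a single thread, unless the
# OPENMM_CPU_THREADS environment variable overrides it.
cpu_platform = omm_compute.get_openmm_platform(
    platform_name="CPU",
    restrict_cpu_count=True,
)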
12 changes: 12 additions & 0 deletions openfe/protocols/openmm_utils/omm_settings.py
@@ -311,6 +311,18 @@ class OpenMMEngineSettings(SettingsBaseModel):
OpenMM compute platform to perform MD integration with. If ``None``, will
choose fastest available platform. Default ``None``.
"""
gpu_device_index: Optional[list[int]] = None
"""
List of integer indices for the GPU devices to select when
``compute_platform`` is either set to ``CUDA`` or ``OpenCL``.
If ``None``, the default OpenMM GPU selection behaviour is used.
See the `OpenMM platform properties documentation <http://docs.openmm.org/latest/userguide/library/04_platform_specifics.html>`_
for more details.
Default ``None``.
"""


class IntegratorSettings(SettingsBaseModel):
6 changes: 3 additions & 3 deletions openfe/tests/protocols/test_openmm_equil_rfe_protocols.py
@@ -38,15 +38,15 @@
from openfe.protocols.openmm_rfe.equil_rfe_methods import (
_validate_alchemical_components, _get_alchemical_charge_difference
)
from openfe.protocols.openmm_utils import system_creation
from openfe.protocols.openmm_utils import system_creation, omm_compute
from openfe.protocols.openmm_utils.charge_generation import (
HAS_NAGL, HAS_OPENEYE, HAS_ESPALOMA
)


def test_compute_platform_warn():
with pytest.warns(UserWarning, match="Non-GPU platform selected: CPU"):
openmm_rfe._rfe_utils.compute.get_openmm_platform('CPU')
with pytest.warns(UserWarning, match="Non-CUDA platform selected: CPU"):
omm_compute.get_openmm_platform('CPU')


def test_append_topology(benzene_complex_system, toluene_complex_system):
11 changes: 9 additions & 2 deletions openfe/tests/protocols/test_rfe_tokenization.py
@@ -39,13 +39,20 @@ def instance(self):

class TestRelativeHybridTopologyProtocol(GufeTokenizableTestsMixin):
cls = openmm_rfe.RelativeHybridTopologyProtocol
key = "RelativeHybridTopologyProtocol-fbc7c8ac0f58808ad4430a155453932f"
repr = f"<{key}>"
key = None
repr = "<RelativeHybridTopologyProtocol-"

@pytest.fixture()
def instance(self, protocol):
return protocol

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)


class TestRelativeHybridTopologyProtocolUnit(GufeTokenizableTestsMixin):
cls = openmm_rfe.RelativeHybridTopologyProtocolUnit
22 changes: 18 additions & 4 deletions openfe/tests/protocols/test_solvation_afe_tokenization.py
@@ -49,13 +49,20 @@ def protocol_result(afe_solv_transformation_json):

class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocol
key = "AbsoluteSolvationProtocol-36e2e292503864aac09e2d5066f24be1"
repr = f"<{key}>"
key = None
repr = "AbsoluteSolvationProtocol-"

@pytest.fixture()
def instance(self, protocol):
return protocol

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)


class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationSolventUnit
@@ -93,9 +100,16 @@ def test_repr(self, instance):

class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocolResult
key = "AbsoluteSolvationProtocolResult-7f80c1cf5a526bde45d385cee7352428"
repr = f"<{key}>"
key = None
repr = "AbsoluteSolvationProtocolResult-"

@pytest.fixture()
def instance(self, protocol_result):
return protocol_result

def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
2 changes: 1 addition & 1 deletion openfe/utils/handle_trajectories.py
@@ -78,7 +78,7 @@ def _create_new_dataset(filename: Path, n_atoms: int,
AMBER Conventions compliant NetCDF dataset to store information
contained in MultiState reporter generated NetCDF file.
"""
ncfile = nc.Dataset(filename, 'w', format='NETCDF3_64BIT')
ncfile = nc.Dataset(filename, 'w', format='NETCDF3_64BIT_OFFSET')
ncfile.Conventions = 'AMBER'
ncfile.ConventionVersion = "1.0"
ncfile.application = "openfe"
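For reference, a minimal sketch of the updated call (hypothetical filename); ``NETCDF3_64BIT_OFFSET`` is the explicit spelling of the format previously accepted as ``NETCDF3_64BIT``:

# Illustrative example with a hypothetical filename.
import netCDF4 as nc

ncfile = nc.Dataset("multistate_output.nc", "w", format="NETCDF3_64BIT_OFFSET")
ncfile.Conventions = "AMBER"
ncfile.ConventionVersion = "1.0"
ncfile.close()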
