diff --git a/news/support_hip.rst b/news/support_hip.rst new file mode 100644 index 000000000..74fb4f69f --- /dev/null +++ b/news/support_hip.rst @@ -0,0 +1,23 @@ +**Added:** + +* HIP Platform Support + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py index b3bee28f6..d325b070c 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/compute.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/compute.py @@ -35,6 +35,7 @@ def get_openmm_platform(platform_name=None): 'cpu': 'CPU', 'opencl': 'OpenCL', 'cuda': 'CUDA', + 'hip': 'HIP', }[str(platform_name).lower()] except KeyError: pass @@ -43,14 +44,14 @@ def get_openmm_platform(platform_name=None): platform = Platform.getPlatformByName(platform_name) # Set precision and properties name = platform.getName() - if name in ['CUDA', 'OpenCL']: + if name in ['CUDA', 'OpenCL', 'HIP']: platform.setPropertyDefaultValue( 'Precision', 'mixed') - if name == 'CUDA': + if name in ['CUDA', 'HIP']: platform.setPropertyDefaultValue( 'DeterministicForces', 'true') - if name != 'CUDA': + if name not in ['CUDA', 'HIP']: wmsg = (f"Non-GPU platform selected: {name}, this may significantly " "impact simulation performance") warnings.warn(wmsg)