diff --git a/test/test_openmm_plugin.py b/test/test_openmm_plugin.py index eac47e5..cc9d3bb 100644 --- a/test/test_openmm_plugin.py +++ b/test/test_openmm_plugin.py @@ -32,8 +32,11 @@ def forward(self, x): para = mymodule(_d.clone()) from molearn.loss_functions import openmm_energy + device_name = "CPU" + if torch.cuda.is_available(): + device_name = "CUDA" openmmscore = openmm_energy( - data.mol, data.std, clamp=None, platform="CUDA" + data.mol, data.std, clamp=None, platform=device_name ) # xml_file = ['modified_amber_protein.xml',]) opt = torch.optim.SGD(para.parameters(), lr=0.0001) scores = []