Skip to content

Commit

Permalink
updating device name from CUDA to what's available
Browse files Browse the repository at this point in the history
  • Loading branch information
gwirn committed Jun 5, 2024
1 parent 0ab4784 commit 910e49e
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion test/test_openmm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ def forward(self, x):
para = mymodule(_d.clone())
from molearn.loss_functions import openmm_energy

device_name = "CPU"
if torch.cuda.is_available():
device_name = "CUDA"
openmmscore = openmm_energy(
data.mol, data.std, clamp=None, platform="CUDA"
data.mol, data.std, clamp=None, platform=device_name
) # xml_file = ['modified_amber_protein.xml',])
opt = torch.optim.SGD(para.parameters(), lr=0.0001)
scores = []
Expand Down

0 comments on commit 910e49e

Please sign in to comment.