Skip to content

Commit

Permalink
[FIX] better device allocation for model
Browse files Browse the repository at this point in the history
  • Loading branch information
BenCretois committed Sep 4, 2023
1 parent fa64650 commit 805784b
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@
format='%(asctime)s - %(levelname)s - %(message)s')

def initModel(model_path, device):
model = torch.load(model_path)
model = torch.load(model_path, map_location=torch.device(device))
model.eval()

return model.to(device)
return model

def compute_hr(array):

Expand Down

0 comments on commit 805784b

Please sign in to comment.