Skip to content

Commit

Permalink
Adds safety flags to torch model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu committed Aug 24, 2024
1 parent 6ec74a3 commit 1ded99a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/poli/core/util/proteins/rasp/load_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def load_cavity_and_downstream_models(
best_cavity_model_path = RASP_DIR / "cavity_model_15.pt"
cavity_model_net = CavityModel(get_latent=True).to(DEVICE)
cavity_model_net.load_state_dict(
torch.load(f"{best_cavity_model_path}", map_location=DEVICE)
torch.load(f"{best_cavity_model_path}", map_location=DEVICE, weights_only=True)
)
cavity_model_net.eval()
ds_model_net = DownstreamModel().to(DEVICE)
Expand Down

0 comments on commit 1ded99a

Please sign in to comment.