From 089602a70f81c96b7f0f27408947f6aaddeae480 Mon Sep 17 00:00:00 2001 From: Jonas Eschmann Date: Mon, 4 Dec 2023 20:49:02 -0500 Subject: [PATCH] Fixing model loading on Apple Silicon (when DEVICE is set to "mps") --- demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo.py b/demo.py index 5abc1da8..a3236c03 100644 --- a/demo.py +++ b/demo.py @@ -41,7 +41,7 @@ def viz(img, flo): def demo(args): model = torch.nn.DataParallel(RAFT(args)) - model.load_state_dict(torch.load(args.model)) + model.load_state_dict(torch.load(args.model, map_location=torch.device(DEVICE))) model = model.module model.to(DEVICE)