-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathconvert_model.py
40 lines (34 loc) · 1.41 KB
/
convert_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import os
from Networks import netvlad, superpoint, superglue, ultrapoint
def _convert_if_missing(path, label, build, device):
    """Build a model, compile it with TorchScript, and save it to *path*.

    Does nothing except print a message when *path* already exists, so the
    model construction inside *build* only runs when a conversion is needed.

    Args:
        path: Destination file for the serialized TorchScript module.
        label: Name used in the progress message ("<label> Converted" /
            "<label> Exist").
        build: Zero-argument callable returning the ``nn.Module`` to convert.
        device: Device the model is moved to before scripting.

    Returns:
        True if a model was converted and saved, False if *path* already existed.
    """
    if os.path.exists(path):
        print(f"{label} Exist")
        return False
    # Saving fails if the parent directory is missing (e.g. a fresh checkout
    # without a ``models/`` folder), so create it up front.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    model = build().eval().to(device)
    torch.jit.script(model).save(path)
    print(f"{label} Converted")
    return True


def main():
    """Convert each network to a TorchScript file under ``models/``.

    Output files that already exist are skipped. Raises ``SystemExit`` when no
    CUDA device is available.
    """
    # Checked explicitly rather than with ``assert`` so the guard survives ``python -O``.
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required to convert the models, but no CUDA device is available.")
    device = 'cuda'
    # (output path, label, builder). Builders are lambdas so each model is only
    # constructed when its output file is missing, matching the original script.
    conversions = [
        ("models/SuperPoint_300.pt", "SuperPoint_300",
         lambda: superpoint.SuperPoint(nms_radius=3, max_keypoints=300)),
        ("models/SuperGlue_outdoor.pt", "SuperGlue_outdoor",
         lambda: superglue.SuperGlue(weights='outdoor', sinkhorn_iterations=50)),
        ("models/NetVLAD.pt", "NetVLAD",
         lambda: netvlad.NetVLAD()),
        ("models/UltraPoint.pt", "UltraPoint",
         lambda: ultrapoint.UltraPoint()),
    ]
    for path, label, build in conversions:
        _convert_if_missing(path, label, build, device)


if __name__ == "__main__":
    main()