-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsuperpoint_extraction_airloc.py
97 lines (80 loc) · 3.16 KB
/
superpoint_extraction_airloc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import argparse
import yaml
from tqdm import tqdm
import torch
import cv2
from model.build_model import build_superpoint_model
from model.inference import superpoint_inference
def find(lst, key, value):
    """Return positions and ``'id'`` fields of records whose ``key`` matches.

    A record ``dic`` matches when ``value in dic[key][0]``, i.e. the test is
    applied to the *first element* of ``dic[key]`` (a substring test if that
    element is a string, a membership test if it is a container).

    Args:
        lst: Sequence of dict-like records. Every record must contain ``key``;
            matching records must also contain ``'id'``.
        key: Record field whose first element is searched.
        value: Value to look for.

    Returns:
        ``(indices, ids)``: two parallel lists holding the positions of the
        matching records in ``lst`` and their ``'id'`` values.
    """
    indices = []
    ids = []  # named ``ids`` so the builtin ``id`` is not shadowed
    for i, dic in enumerate(lst):
        if value in dic[key][0]:
            indices.append(i)
            ids.append(dic['id'])
    return indices, ids
def _subdirs(path):
    """Return the sorted names of the immediate subdirectories of *path*."""
    # Skip stray files (e.g. .DS_Store, READMEs) that would otherwise make the
    # nested os.listdir calls fail; sort for a reproducible processing order.
    return sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )


def _load_gray_image(img_path, device):
    """Load *img_path* as a 3-channel grayscale float32 tensor scaled to [0, 1].

    The image is read with OpenCV (BGR), converted to grayscale, replicated
    into three identical channels, moved to *device*, permuted from HWC to
    CHW (shape ``(3, H, W)``) and divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    src = cv2.imread(img_path)
    if src is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cvtColor fails with an opaque "!_src.empty()".
        raise FileNotFoundError(f"Could not read image: {img_path}")
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
    # Replicate the gray channel to 3 channels (same as the original pipeline;
    # presumably the layout the SuperPoint wrapper expects - confirm).
    gray3 = cv2.merge([gray, gray, gray])
    image = torch.from_numpy(gray3).type(torch.float32).to(device)
    image = image.permute(2, 0, 1)  # HWC -> CHW
    image /= 255
    return image


def inference(configs):
    """Run SuperPoint over every RGB frame of every room of every dataset.

    Walks ``<img_data_path>/<dataset>/<scene>/rooms/<room>/raw_data/`` for each
    dataset in ``configs['datasets']``. It feeds every ``*rgb.png`` image
    through the SuperPoint model via ``superpoint_inference`` and passes it
    ``<room>/points``, which is created if missing. That directory is
    presumably where the results are written; confirm in model.inference.

    Args:
        configs: Parsed YAML configuration. Keys read here: ``'data'``,
            ``['model']['superpoint']['detection_threshold']``,
            ``'img_data_path'`` and ``'datasets'``. ``'num_gpu'`` and
            ``'public_model'`` are overwritten before the model is built.

    Raises:
        KeyError: if a required configuration key is missing.
        FileNotFoundError: if an image cannot be read by OpenCV.
    """
    # data config
    data_config = configs['data']
    # superpoint model config
    superpoint_model_config = configs['model']['superpoint']
    detection_threshold = superpoint_model_config['detection_threshold']
    # Fixed overrides set before build_superpoint_model is called.
    # NOTE(review): num_gpu is hard-coded to [0] even though main() stores the
    # CLI --gpu value in configs['use_gpu'] - confirm this is intended.
    configs['num_gpu'] = [0]
    configs['public_model'] = 0
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    base_dir = configs['img_data_path']
    datasets = configs['datasets']

    # model
    superpoint_model = build_superpoint_model(configs)

    for dataset in datasets:
        print("Dataset Name = ", dataset)
        dataset_path = os.path.join(base_dir, dataset)
        for scene in _subdirs(dataset_path):
            print("Scene Name = ", scene)
            rooms_dir = os.path.join(dataset_path, scene, "rooms")
            for room_name in _subdirs(rooms_dir):
                print("Room Name = ", room_name)
                room_path = os.path.join(rooms_dir, room_name)
                raw_data_folder = os.path.join(room_path, "raw_data")
                points_dir = os.path.join(room_path, "points")
                # exist_ok replaces the racy isdir()+mkdir() check.
                os.makedirs(points_dir, exist_ok=True)
                # Filter up front so the progress bar counts only processed frames.
                img_names = sorted(
                    name for name in os.listdir(raw_data_folder)
                    if name.endswith("rgb.png")
                )
                for img_name in tqdm(img_names):
                    img_path = os.path.join(raw_data_folder, img_name)
                    data = {
                        'image': [_load_gray_image(img_path, device)],
                        # File-name prefix before the first '_'.
                        'image_name': [img_name.split('_')[0]],
                    }
                    with torch.no_grad():
                        superpoint_inference(
                            superpoint_model, data, data_config,
                            detection_threshold, points_dir,
                        )
def main(argv=None):
    """Parse command-line arguments, load the YAML config and run inference.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. This is useful
            for programmatic calls and tests; ``None`` keeps the CLI behavior.

    Exits with a usage error (status 2) if no config file is given or if the
    file does not contain a YAML mapping.
    """
    parser = argparse.ArgumentParser(description="SuperPoint Feature Extraction")
    parser.add_argument(
        "-c", "--config_file",
        dest="config_file",
        type=str,
        default="",
        help="path to the YAML configuration file (required)",
    )
    parser.add_argument(
        "-g", "--gpu",
        dest="gpu",
        type=int,
        default=1,
        help="value stored in configs['use_gpu'] (default: 1)",
    )
    args = parser.parse_args(argv)
    if not args.config_file:
        # Report a usage error instead of the cryptic open("") failure.
        parser.error("a configuration file is required (-c/--config_file)")
    # `with` closes the file deterministically; safe_load accepts a stream.
    with open(args.config_file, 'r', encoding='utf-8') as f:
        configs = yaml.safe_load(f)
    if not isinstance(configs, dict):
        # An empty file loads as None; item assignment below would then fail
        # with an unhelpful TypeError.
        parser.error(f"config file {args.config_file!r} must contain a YAML mapping")
    configs['use_gpu'] = args.gpu
    inference(configs)


if __name__ == "__main__":
    main()