-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathplace_rec_DINO_finetuned.py
119 lines (90 loc) · 5.54 KB
/
place_rec_DINO_finetuned.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import cv2
import os
import matplotlib.pyplot as plt
import matplotlib
from natsort import natsorted
import networkx as nx
import h5py
from glob import glob
import pickle
from importlib import reload
from tqdm import tqdm
import numpy as np
import argparse
import func_vpr
from place_rec_global_config import datasets, workdir_data
def _build_cfg_dino(width, height):
    """Return the DINO feature-extraction config shared by both extractors.

    The dict is identical for the finetuned-NV and SALAD pipelines, so it is
    built once per run instead of once per dataset split. Most keys are the
    "robohop specific params" carried over from that codebase; which of them
    ``func_vpr`` actually reads is not visible from this script.
    """
    return {
        "desired_width": width,
        "desired_height": height,
        "detect": 'dino',
        "use_sam": True,
        "class_threshold": 0.9,
        "desired_feature": 0,
        "query_type": 'text',
        "sort_by": 'area',
        "use_16bit": False,
        "use_cuda": True,
        "dino_strides": 4,
        "use_traced_model": False,
        "rmin": 0,
        "DAStoreFull": False,
        "dinov2": True,
        "wrap": False,
        "resize": True,
    }


def _list_images(data_path, start=0, end=None, step=1):
    """Return the naturally sorted entries of *data_path*, sliced as [start:end][::step]."""
    # os.listdir order is arbitrary; natsort gives a stable, human-friendly order
    # so descriptor rows line up with image indices across runs.
    ims = natsorted(os.listdir(data_path))
    return ims[start:end][::step]


def _extraction_jobs(workdir, dataset, tag, width, data_path_r, data_path_q):
    """Return one job dict per split: reference (r) first, then query (q).

    Each job holds the image directory ("dataPath") and the output file
    ("h5FullPathDINO"). *tag* (e.g. "dinoNV" or "dinoSALAD") is part of the
    file name, so different extractors never write to the same .h5 file.
    """
    return [
        {"dataPath": data_path_r, "h5FullPathDINO": f"{workdir}/{dataset}_r_{tag}_{width}.h5"},
        {"dataPath": data_path_q, "h5FullPathDINO": f"{workdir}/{dataset}_q_{tag}_{width}.h5"},
    ]


def main():
    """Extract finetuned DINO descriptors for the reference and query images of one dataset.

    Reads ``--dataset`` from the command line, looks it up in
    ``place_rec_global_config.datasets`` and, for every enabled extractor,
    writes one .h5 file per split into ``{workdir_data}/{dataset}/out``.

    Raises:
        ValueError: if the dataset is not present (or has an empty entry) in
            the configuration.
    """
    # Which extractors to run. Both may be enabled at once: each one uses its own
    # checkpoint and writes its own .h5 files.
    DINONV_extraction = True  # SegVLAD finetuned
    DINOSALAD_extraction = False  # SALAD
    print(f"DINONV_extraction: {DINONV_extraction}, DINOSALAD_extraction: {DINOSALAD_extraction}")

    parser = argparse.ArgumentParser(description='SAM/DINO/FastSAM extraction for Any Dataset. See place_rec_global_config.py to see how to give arguments.')
    parser.add_argument('--dataset', required=True, help='Dataset name')  # baidu, pitts etc
    args = parser.parse_args()

    # Load dataset and experiment configurations.
    dataset_config = datasets.get(args.dataset, {})
    if not dataset_config:
        raise ValueError(f"Dataset '{args.dataset}' not found in configuration.")
    print(dataset_config)
    cfg = dataset_config['cfg']

    # DINO extraction always uses the resolution configured for the dataset.
    width_DINO, height_DINO = cfg['desired_width'], cfg['desired_height']
    print(f"IMPORTANT: The dimensions being used for DINO extraction are {width_DINO}x{height_DINO} pixels.")

    workdir = f'{workdir_data}/{args.dataset}/out'
    os.makedirs(workdir, exist_ok=True)

    # Optional slice over the sorted image list (start, stop, step); the
    # defaults process every image.
    ims_sidx, ims_eidx, ims_step = 0, None, 1
    # Trailing slashes are kept as-is because these strings are handed
    # verbatim to func_vpr.
    dataPath1_r = f"{workdir_data}/{args.dataset}/{dataset_config['data_subpath1_r']}/"
    dataPath2_q = f"{workdir_data}/{args.dataset}/{dataset_config['data_subpath2_q']}/"

    # The config is the same for every split and both extractors: build it once.
    cfg_dino = _build_cfg_dino(width_DINO, height_DINO)

    # EXTRACTION STARTS. To process only some splits, slice the job list
    # (e.g. list_all[1:] for queries only, list_all[:1] for references only).
    if DINONV_extraction:
        dino_nv_checkpoint = f"{workdir_data}/models/DnV2_NV/last.ckpt"
        list_all = _extraction_jobs(workdir, args.dataset, "dinoNV", width_DINO, dataPath1_r, dataPath2_q)
        # Load the checkpoint once and reuse it for both splits instead of
        # reloading it on every iteration.
        backbone = func_vpr.loadDINONV(cfg_dino, dino_nv_checkpoint, device="cuda", feat_type="backbone")
        for iter_dict in list_all:
            dataPath = iter_dict["dataPath"]
            ims = _list_images(dataPath, ims_sidx, ims_eidx, ims_step)
            h5FullPathDINONV = iter_dict["h5FullPathDINO"]
            print("DINONV extraction started...")
            func_vpr.process_DINONV(backbone, ims, cfg_dino, h5FullPathDINONV, dataPath)
            print(f"\n \n DINONV EXTRACTED DONE at path: {h5FullPathDINONV} \n \n ")
        # Drop the model reference before a possible SALAD pass loads its own.
        del backbone

    if DINOSALAD_extraction:
        # Own checkpoint and own output paths: previously this branch reused the
        # dinoNV job list and had no checkpoint when both flags were enabled.
        dino_salad_checkpoint = f"{workdir_data}/models/dino_salad.ckpt"
        list_all = _extraction_jobs(workdir, args.dataset, "dinoSALAD", width_DINO, dataPath1_r, dataPath2_q)
        dino = func_vpr.loadDINOSALAD(cfg_dino, ckpt_path=dino_salad_checkpoint, device="cuda", feat_type="full")
        for iter_dict in list_all:
            dataPath = iter_dict["dataPath"]
            ims = _list_images(dataPath, ims_sidx, ims_eidx, ims_step)
            h5FullPathDINOSALAD = iter_dict["h5FullPathDINO"]
            print("DINO extraction started...")
            func_vpr.process_dino_salad_ft_to_h5(h5FullPathDINOSALAD, cfg_dino, ims, dino, dataDir=dataPath, device="cuda", feat_type="full", feat_return='f')
            print("\n \n DINOSALAD EXTRACTED DONE \n \n ")
        del dino


if __name__ == "__main__":
    main()