From cd5a4875bbc53fbb2d49b08a45a922895081694a Mon Sep 17 00:00:00 2001 From: Johan Edstedt Date: Mon, 22 Apr 2024 15:40:25 +0200 Subject: [PATCH] https://github.com/Parskatt/DeDoDe/issues/30 --- DeDoDe/model_zoo/__init__.py | 2 +- DeDoDe/model_zoo/dedode_models.py | 58 ++++++++++++++++++++- experiments/eval/eval_dedode_S-v2---B-v1.py | 40 ++++++++++++++ 3 files changed, 98 insertions(+), 2 deletions(-) create mode 100644 experiments/eval/eval_dedode_S-v2---B-v1.py diff --git a/DeDoDe/model_zoo/__init__.py b/DeDoDe/model_zoo/__init__.py index 0775d43..11a9edc 100644 --- a/DeDoDe/model_zoo/__init__.py +++ b/DeDoDe/model_zoo/__init__.py @@ -1,3 +1,3 @@ -from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G +from .dedode_models import dedode_detector_S, dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G \ No newline at end of file diff --git a/DeDoDe/model_zoo/dedode_models.py b/DeDoDe/model_zoo/dedode_models.py index deac312..9bca8cc 100644 --- a/DeDoDe/model_zoo/dedode_models.py +++ b/DeDoDe/model_zoo/dedode_models.py @@ -4,9 +4,65 @@ from DeDoDe.detectors.dedode_detector import DeDoDeDetector from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor from DeDoDe.decoder import ConvRefiner, Decoder -from DeDoDe.encoder import VGG19, VGG, VGG_DINOv2 +from DeDoDe.encoder import VGG11, VGG, VGG_DINOv2 from DeDoDe.utils import get_best_device +def dedode_detector_S(device = get_best_device(), weights = None, remove_borders = False): + if weights is None: + weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_S_v2.pth", map_location = device) + NUM_PROTOTYPES = 1 + residual = True + hidden_blocks = 4 + amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + amp = True + conv_refiner = nn.ModuleDict( + { + "8": ConvRefiner( + 512, + 256, + 128 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + "4": ConvRefiner( + 256+128, + 128, + 64 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "2": ConvRefiner( + 128+64, + 64, + 32 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + + ), + "1": ConvRefiner( + 64 + 32, + 32, + 1 + NUM_PROTOTYPES, + hidden_blocks = hidden_blocks, + residual = residual, + amp = amp, + amp_dtype = amp_dtype, + ), + } + ) + encoder = VGG11(pretrained = False, amp = amp, amp_dtype = amp_dtype) + decoder = Decoder(conv_refiner) + model = DeDoDeDetector(encoder = encoder, decoder = decoder, remove_borders = remove_borders).to(device) + if weights is not None: + model.load_state_dict(weights) + return model def dedode_detector_B(device = get_best_device(), weights = None): residual = True diff --git a/experiments/eval/eval_dedode_S-v2---B-v1.py b/experiments/eval/eval_dedode_S-v2---B-v1.py new file mode 100644 index 0000000..e6ff5b1 --- /dev/null +++ b/experiments/eval/eval_dedode_S-v2---B-v1.py @@ -0,0 +1,40 @@ +import os +from argparse import ArgumentParser + +import torch +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.utils.data import ConcatDataset +import torch.nn as nn + +from DeDoDe.train import train_k_steps +from DeDoDe.datasets.megadepth import MegadepthBuilder +from DeDoDe.descriptors.descriptor_loss import DescriptorLoss +from DeDoDe.checkpoint import CheckPoint +from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor +from DeDoDe.encoder import VGG +from DeDoDe.decoder import ConvRefiner, Decoder +from DeDoDe import dedode_detector_S, dedode_descriptor_B +from DeDoDe.benchmarks import MegaDepthPoseMNNBenchmark +#from DeDoDe import dedode_detector_L, dedode_descriptor_B +from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher +#from DeDoDe.matchers.soft_dual_softmax_matcher import SoftDualSoftMaxMatcher + + +from DeDoDe.utils import * +from PIL import Image +import cv2 +import numpy as np + + +if __name__ == "__main__": + device = get_best_device() + detector = dedode_detector_S(weights = torch.load("dedode_detector_S_v2.pth", map_location = device)) + descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth", map_location = device)) + matcher = DualSoftMaxMatcher() + + mega_1500 = MegaDepthPoseMNNBenchmark() + mega_1500.benchmark( + detector_model = detector, + descriptor_model = descriptor, + matcher_model = matcher) \ No newline at end of file