Skip to content

Commit

Permalink
https://github.com/Parskatt/DeDoDe/issues/30
Browse files Browse the repository at this point in the history
  • Loading branch information
Parskatt committed Apr 22, 2024
1 parent 55cb056 commit cd5a487
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DeDoDe/model_zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dedode_models import dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G
from .dedode_models import dedode_detector_S, dedode_detector_B, dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G


58 changes: 57 additions & 1 deletion DeDoDe/model_zoo/dedode_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,65 @@
from DeDoDe.detectors.dedode_detector import DeDoDeDetector
from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor
from DeDoDe.decoder import ConvRefiner, Decoder
from DeDoDe.encoder import VGG19, VGG, VGG_DINOv2
from DeDoDe.encoder import VGG11, VGG, VGG_DINOv2
from DeDoDe.utils import get_best_device

def dedode_detector_S(device = get_best_device(), weights = None, remove_borders = False):
if weights is None:
weights = torch.hub.load_state_dict_from_url("https://github.com/Parskatt/DeDoDe/releases/download/v2/dedode_detector_S_v2.pth", map_location = device)
NUM_PROTOTYPES = 1
residual = True
hidden_blocks = 4
amp_dtype = torch.float16#torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
amp = True
conv_refiner = nn.ModuleDict(
{
"8": ConvRefiner(
512,
256,
128 + NUM_PROTOTYPES,
hidden_blocks = hidden_blocks,
residual = residual,
amp = amp,
amp_dtype = amp_dtype,
),
"4": ConvRefiner(
256+128,
128,
64 + NUM_PROTOTYPES,
hidden_blocks = hidden_blocks,
residual = residual,
amp = amp,
amp_dtype = amp_dtype,

),
"2": ConvRefiner(
128+64,
64,
32 + NUM_PROTOTYPES,
hidden_blocks = hidden_blocks,
residual = residual,
amp = amp,
amp_dtype = amp_dtype,

),
"1": ConvRefiner(
64 + 32,
32,
1 + NUM_PROTOTYPES,
hidden_blocks = hidden_blocks,
residual = residual,
amp = amp,
amp_dtype = amp_dtype,
),
}
)
encoder = VGG11(pretrained = False, amp = amp, amp_dtype = amp_dtype)
decoder = Decoder(conv_refiner)
model = DeDoDeDetector(encoder = encoder, decoder = decoder, remove_borders = remove_borders).to(device)
if weights is not None:
model.load_state_dict(weights)
return model

def dedode_detector_B(device = get_best_device(), weights = None):
residual = True
Expand Down
40 changes: 40 additions & 0 deletions experiments/eval/eval_dedode_S-v2---B-v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
from argparse import ArgumentParser

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import ConcatDataset
import torch.nn as nn

from DeDoDe.train import train_k_steps
from DeDoDe.datasets.megadepth import MegadepthBuilder
from DeDoDe.descriptors.descriptor_loss import DescriptorLoss
from DeDoDe.checkpoint import CheckPoint
from DeDoDe.descriptors.dedode_descriptor import DeDoDeDescriptor
from DeDoDe.encoder import VGG
from DeDoDe.decoder import ConvRefiner, Decoder
from DeDoDe import dedode_detector_S, dedode_descriptor_B
from DeDoDe.benchmarks import MegaDepthPoseMNNBenchmark
#from DeDoDe import dedode_detector_L, dedode_descriptor_B
from DeDoDe.matchers.dual_softmax_matcher import DualSoftMaxMatcher
#from DeDoDe.matchers.soft_dual_softmax_matcher import SoftDualSoftMaxMatcher


from DeDoDe.utils import *
from PIL import Image
import cv2
import numpy as np


if __name__ == "__main__":
device = get_best_device()
detector = dedode_detector_S(weights = torch.load("dedode_detector_S_v2.pth", map_location = device))
descriptor = dedode_descriptor_B(weights = torch.load("dedode_descriptor_B.pth", map_location = device))
matcher = DualSoftMaxMatcher()

mega_1500 = MegaDepthPoseMNNBenchmark()
mega_1500.benchmark(
detector_model = detector,
descriptor_model = descriptor,
matcher_model = matcher)

0 comments on commit cd5a487

Please sign in to comment.