Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added RoMa #812

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ output/
# Cache dir
cache/
*.mat
*.tsv

# Dev notebooks
notebooks/
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "thirdparty/LightGlue"]
path = thirdparty/LightGlue
url = https://github.com/cvg/LightGlue.git
[submodule "thirdparty/RoMa"]
path = thirdparty/RoMa
url = https://github.com/Parskatt/RoMa.git
6 changes: 4 additions & 2 deletions gtsfm/configs/correspondence/loftr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ CorrespondenceGenerator:

matcher:
_target_: gtsfm.frontend.cacher.image_matcher_cacher.ImageMatcherCacher
matcher_obj:
matcher_obj:
_target_: gtsfm.frontend.matcher.loftr.LOFTR


aggregator:
_target_: gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup.KeypointAggregatorDedup
9 changes: 9 additions & 0 deletions gtsfm/configs/correspondence/roma.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CorrespondenceGenerator:
_target_: gtsfm.frontend.correspondence_generator.image_correspondence_generator.ImageCorrespondenceGenerator

matcher:
_target_: gtsfm.frontend.matcher.roma.RoMa

aggregator:
_target_: gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup.KeypointAggregatorDedup
nms_merge_radius: 1e-4
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,35 @@
"""
from typing import Any, Dict, List, Optional, Tuple

from dask.distributed import Client, Future
import numpy as np

from dask.distributed import Client, Future
from gtsfm.common.image import Image
from gtsfm.common.keypoints import Keypoints
from gtsfm.common.pose_prior import PosePrior
from gtsfm.common.types import CALIBRATION_TYPE, CAMERA_TYPE
from gtsfm.frontend.correspondence_generator.correspondence_generator_base import CorrespondenceGeneratorBase
from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_base import KeypointAggregatorBase
from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup import (
KeypointAggregatorDedup,
)
from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_unique import (
KeypointAggregatorUnique,
)
from gtsfm.frontend.correspondence_generator.correspondence_generator_base import \
CorrespondenceGeneratorBase
from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_base import \
KeypointAggregatorBase
from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup import \
KeypointAggregatorDedup
from gtsfm.frontend.matcher.image_matcher_base import ImageMatcherBase
from gtsfm.two_view_estimator import TWO_VIEW_OUTPUT, TwoViewEstimator


class ImageCorrespondenceGenerator(CorrespondenceGeneratorBase):
"""Pair-wise direct matching of images (e.g. transformer-based)."""

def __init__(self, matcher: ImageMatcherBase, deduplicate: bool = True) -> None:
def __init__(
self, matcher: ImageMatcherBase, aggregator: KeypointAggregatorBase = KeypointAggregatorDedup()
) -> None:
"""
Args:
matcher: Matcher to use.
deduplicate: Whether to de-duplicate with a single image the detections received from each image pair.
"""
self._matcher = matcher

self._aggregator: KeypointAggregatorBase = (
KeypointAggregatorDedup() if deduplicate else KeypointAggregatorUnique()
)
self._aggregator = aggregator

def __repr__(self) -> str:
return f"""
Expand Down
87 changes: 87 additions & 0 deletions gtsfm/frontend/matcher/roma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""RoMa image matcher.

The network was proposed in "RoMa: Revisiting Robust Losses for Dense Feature Matching".

References:
- https://arxiv.org/html/2305.15404v2

Authors: Travis Driver
"""
from typing import Tuple

import numpy as np
import PIL
import torch
from gtsfm.common.image import Image
from gtsfm.common.keypoints import Keypoints
from gtsfm.frontend.matcher.image_matcher_base import ImageMatcherBase
from romatch import roma_indoor, roma_outdoor


class RoMa(ImageMatcherBase):
"""RoMa image matcher."""

def __init__(
self,
use_cuda: bool = True,
min_confidence: float = 0.1,
max_keypoints: int = 8000,
use_indoor_model: bool = False,
) -> None:
"""Initialize the matcher.

Args:
use_outdoor_model (optional): use the outdoor pretrained model. Defaults to True.
use_cuda (optional): use CUDA for inference on GPU. Defaults to True.
min_confidence(optional): Minimum confidence required for matches. Defaults to 0.95.
upsample_res: resolution of upsampled warp and certainty maps. Stored as (H, W).
"""
super().__init__()
self._min_confidence = min_confidence
self._max_keypoints = max_keypoints

# Initialize model.
self._device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
if use_indoor_model:
self._matcher = roma_indoor(self._device).eval()
else:
self._matcher = roma_outdoor(self._device).eval()

def match(self, image_i1: Image, image_i2: Image) -> Tuple[Keypoints, Keypoints]:
"""Identify feature matches across two images.

Note: the matcher will run out of memory for large image sizes

Args:
image_i1: first input image of pair.
image_i2: second input image of pair.

Returns:
Keypoints from image 1 (N keypoints will exist).
Corresponding keypoints from image 2 (there will also be N keypoints). These represent feature matches.
"""
# Compute dense warp and certainty maps.
with torch.no_grad():
im1 = PIL.Image.fromarray(image_i1.value_array).convert("RGB")
im2 = PIL.Image.fromarray(image_i2.value_array).convert("RGB")
warp, certainty = self._matcher.match(im1, im2, device=self._device)

# Sample keypoints and correspondences from warp.
H1, W1 = image_i1.shape[:2]
H2, W2 = image_i2.shape[:2]
match, certs = self._matcher.sample(warp, certainty, num=self._max_keypoints)
match = match[certs > self._min_confidence]
mkpts1, mkpts2 = self._matcher.to_pixel_coordinates(match, H1, W1, H2, W2)

# Convert to GTSfM keypoints and filter by mask.
keypoints_i1 = Keypoints(coordinates=mkpts1.cpu().numpy())
keypoints_i2 = Keypoints(coordinates=mkpts2.cpu().numpy())
valid_ind = np.arange(len(keypoints_i1))
if image_i1.mask is not None:
_, valid_ind_i1 = keypoints_i1.filter_by_mask(image_i1.mask)
valid_ind = np.intersect1d(valid_ind, valid_ind_i1)
if image_i2.mask is not None:
_, valid_ind_i2 = keypoints_i2.filter_by_mask(image_i2.mask)
valid_ind = np.intersect1d(valid_ind, valid_ind_i2)

return keypoints_i1.extract_indices(valid_ind), keypoints_i2.extract_indices(valid_ind)
1 change: 1 addition & 0 deletions thirdparty/RoMa
Submodule RoMa added at fa66db
Loading