From e4b70e74b4a3e8f5854e530bc12d05c882aa40b6 Mon Sep 17 00:00:00 2001 From: Travis Driver Date: Fri, 11 Oct 2024 17:51:26 -0400 Subject: [PATCH] Added RoMa --- .gitignore | 1 + .gitmodules | 3 + gtsfm/configs/correspondence/loftr.yaml | 6 +- gtsfm/configs/correspondence/roma.yaml | 9 ++ .../image_correspondence_generator.py | 26 +++--- gtsfm/frontend/matcher/roma.py | 87 +++++++++++++++++++ thirdparty/LightGlue | 2 +- thirdparty/RoMa | 1 + 8 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 gtsfm/configs/correspondence/roma.yaml create mode 100644 gtsfm/frontend/matcher/roma.py create mode 160000 thirdparty/RoMa diff --git a/.gitignore b/.gitignore index 50e11fe25..6c7a153c2 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,7 @@ output/ # Cache dir cache/ *.mat +*.tsv # Dev notebooks notebooks/ diff --git a/.gitmodules b/.gitmodules index 29d1b4816..e4f30bb80 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "thirdparty/LightGlue"] path = thirdparty/LightGlue url = https://github.com/cvg/LightGlue.git +[submodule "thirdparty/RoMa"] + path = thirdparty/RoMa + url = https://github.com/Parskatt/RoMa.git diff --git a/gtsfm/configs/correspondence/loftr.yaml b/gtsfm/configs/correspondence/loftr.yaml index 1f363aa2f..bfe5ba5be 100644 --- a/gtsfm/configs/correspondence/loftr.yaml +++ b/gtsfm/configs/correspondence/loftr.yaml @@ -3,6 +3,8 @@ CorrespondenceGenerator: matcher: _target_: gtsfm.frontend.cacher.image_matcher_cacher.ImageMatcherCacher - matcher_obj: + matcher_obj: _target_: gtsfm.frontend.matcher.loftr.LOFTR - \ No newline at end of file + + aggregator: + _target_: gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup.KeypointAggregatorDedup diff --git a/gtsfm/configs/correspondence/roma.yaml b/gtsfm/configs/correspondence/roma.yaml new file mode 100644 index 000000000..8b4795740 --- /dev/null +++ b/gtsfm/configs/correspondence/roma.yaml @@ -0,0 +1,9 @@ +CorrespondenceGenerator: + _target_: gtsfm.frontend.correspondence_generator.image_correspondence_generator.ImageCorrespondenceGenerator + + matcher: + _target_: gtsfm.frontend.matcher.roma.RoMa + + aggregator: + _target_: gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup.KeypointAggregatorDedup + nms_merge_radius: 1e-4 diff --git a/gtsfm/frontend/correspondence_generator/image_correspondence_generator.py b/gtsfm/frontend/correspondence_generator/image_correspondence_generator.py index 5617ce837..224be7f23 100644 --- a/gtsfm/frontend/correspondence_generator/image_correspondence_generator.py +++ b/gtsfm/frontend/correspondence_generator/image_correspondence_generator.py @@ -4,21 +4,18 @@ """ from typing import Any, Dict, List, Optional, Tuple -from dask.distributed import Client, Future import numpy as np - +from dask.distributed import Client, Future from gtsfm.common.image import Image from gtsfm.common.keypoints import Keypoints from gtsfm.common.pose_prior import PosePrior from gtsfm.common.types import CALIBRATION_TYPE, CAMERA_TYPE -from gtsfm.frontend.correspondence_generator.correspondence_generator_base import CorrespondenceGeneratorBase -from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_base import KeypointAggregatorBase -from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup import ( - KeypointAggregatorDedup, -) -from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_unique import ( - KeypointAggregatorUnique, -) +from gtsfm.frontend.correspondence_generator.correspondence_generator_base import \ + CorrespondenceGeneratorBase +from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_base import \ + KeypointAggregatorBase +from gtsfm.frontend.correspondence_generator.keypoint_aggregator.keypoint_aggregator_dedup import \ + KeypointAggregatorDedup from gtsfm.frontend.matcher.image_matcher_base import ImageMatcherBase from gtsfm.two_view_estimator import TWO_VIEW_OUTPUT, TwoViewEstimator @@ -26,17 +23,16 @@ class ImageCorrespondenceGenerator(CorrespondenceGeneratorBase): """Pair-wise direct matching of images (e.g. transformer-based).""" - def __init__(self, matcher: ImageMatcherBase, deduplicate: bool = True) -> None: + def __init__( + self, matcher: ImageMatcherBase, aggregator: KeypointAggregatorBase = KeypointAggregatorDedup() + ) -> None: """ Args: matcher: Matcher to use. deduplicate: Whether to de-duplicate with a single image the detections received from each image pair. """ self._matcher = matcher - - self._aggregator: KeypointAggregatorBase = ( - KeypointAggregatorDedup() if deduplicate else KeypointAggregatorUnique() - ) + self._aggregator = aggregator def __repr__(self) -> str: return f""" diff --git a/gtsfm/frontend/matcher/roma.py b/gtsfm/frontend/matcher/roma.py new file mode 100644 index 000000000..4a230de9b --- /dev/null +++ b/gtsfm/frontend/matcher/roma.py @@ -0,0 +1,87 @@ +"""RoMa image matcher. + +The network was proposed in "RoMa: Revisiting Robust Losses for Dense Feature Matching". + +References: +- https://arxiv.org/html/2305.15404v2 + +Authors: Travis Driver +""" +from typing import Tuple + +import numpy as np +import PIL +import torch +from gtsfm.common.image import Image +from gtsfm.common.keypoints import Keypoints +from gtsfm.frontend.matcher.image_matcher_base import ImageMatcherBase +from romatch import roma_indoor, roma_outdoor + + +class RoMa(ImageMatcherBase): + """RoMa image matcher.""" + + def __init__( + self, + use_cuda: bool = True, + min_confidence: float = 0.1, + max_keypoints: int = 8000, + use_indoor_model: bool = False, + ) -> None: + """Initialize the matcher. + + Args: + use_outdoor_model (optional): use the outdoor pretrained model. Defaults to True. + use_cuda (optional): use CUDA for inference on GPU. Defaults to True. + min_confidence(optional): Minimum confidence required for matches. Defaults to 0.95. + upsample_res: resolution of upsampled warp and certainty maps. Stored as (H, W). + """ + super().__init__() + self._min_confidence = min_confidence + self._max_keypoints = max_keypoints + + # Initialize model. + self._device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu") + if use_indoor_model: + self._matcher = roma_indoor(self._device).eval() + else: + self._matcher = roma_outdoor(self._device).eval() + + def match(self, image_i1: Image, image_i2: Image) -> Tuple[Keypoints, Keypoints]: + """Identify feature matches across two images. + + Note: the matcher will run out of memory for large image sizes + + Args: + image_i1: first input image of pair. + image_i2: second input image of pair. + + Returns: + Keypoints from image 1 (N keypoints will exist). + Corresponding keypoints from image 2 (there will also be N keypoints). These represent feature matches. + """ + # Compute dense warp and certainty maps. + with torch.no_grad(): + im1 = PIL.Image.fromarray(image_i1.value_array).convert("RGB") + im2 = PIL.Image.fromarray(image_i2.value_array).convert("RGB") + warp, certainty = self._matcher.match(im1, im2, device=self._device) + + # Sample keypoints and correspondences from warp. + H1, W1 = image_i1.shape[:2] + H2, W2 = image_i2.shape[:2] + match, certs = self._matcher.sample(warp, certainty, num=self._max_keypoints) + match = match[certs > self._min_confidence] + mkpts1, mkpts2 = self._matcher.to_pixel_coordinates(match, H1, W1, H2, W2) + + # Convert to GTSfM keypoints and filter by mask. + keypoints_i1 = Keypoints(coordinates=mkpts1.cpu().numpy()) + keypoints_i2 = Keypoints(coordinates=mkpts2.cpu().numpy()) + valid_ind = np.arange(len(keypoints_i1)) + if image_i1.mask is not None: + _, valid_ind_i1 = keypoints_i1.filter_by_mask(image_i1.mask) + valid_ind = np.intersect1d(valid_ind, valid_ind_i1) + if image_i2.mask is not None: + _, valid_ind_i2 = keypoints_i2.filter_by_mask(image_i2.mask) + valid_ind = np.intersect1d(valid_ind, valid_ind_i2) + + return keypoints_i1.extract_indices(valid_ind), keypoints_i2.extract_indices(valid_ind) diff --git a/thirdparty/LightGlue b/thirdparty/LightGlue index fe7fb4fa0..29f3e449e 160000 --- a/thirdparty/LightGlue +++ b/thirdparty/LightGlue @@ -1 +1 @@ -Subproject commit fe7fb4fa0cffec65e33bf4c2f62a863d5b03433a +Subproject commit 29f3e449efa1994758b8a16299d2816028dca65b diff --git a/thirdparty/RoMa b/thirdparty/RoMa new file mode 160000 index 000000000..fa66dbc3f --- /dev/null +++ b/thirdparty/RoMa @@ -0,0 +1 @@ +Subproject commit fa66dbc3fdbe0cdff9b7b9c6f2a7015f026539fb